Source code for framework.models.factory

"""AI Model Factory for Structured Generation.

Creates and configures LLM model instances for use with PydanticAI agents and structured
generation workflows. This factory handles the complexity of provider-specific initialization,
credential management, HTTP client configuration, and proxy setup across multiple AI providers.

The factory supports enterprise-grade features including connection pooling, timeout management,
and automatic HTTP proxy detection through environment variables. Each provider has specific
requirements for API keys, base URLs, and model identifiers that are validated and enforced.

.. note::
   Model instances created here are optimized for structured generation with PydanticAI.
   For direct chat completions without structured outputs, consider using
   :func:`~completion.get_chat_completion` instead.

.. seealso::
   :func:`get_model` : Main factory function for creating model instances
   :func:`~completion.get_chat_completion` : Direct chat completion interface
   :mod:`configs.config` : Provider configuration and credential management
"""

import logging
import os
from typing import Optional, Union
from urllib.parse import urlparse
import httpx
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.anthropic import AnthropicProvider
import openai
from google import genai

from configs.config import get_provider_config


def _create_openai_compatible_model(
    model_id: str,
    api_key: str,
    base_url: Optional[str],
    timeout_arg_from_get_model: Optional[float],
    shared_http_client: Optional[httpx.AsyncClient] = None
) -> OpenAIModel:
    """Create an OpenAI-compatible model instance with proper client configuration.
    
    This internal helper function handles the creation of OpenAI-compatible models
    for providers that use the OpenAI API format (OpenAI, Ollama, CBORG). It manages
    HTTP client configuration, timeout settings, and base URL handling with proper
    fallback behavior.
    
    :param model_id: Model identifier as recognized by the provider
    :type model_id: str
    :param api_key: API authentication key for the provider
    :type api_key: str
    :param base_url: Provider's API base URL, None for OpenAI's default endpoint
    :type base_url: str, optional
    :param timeout_arg_from_get_model: Request timeout in seconds, defaults to 60.0
    :type timeout_arg_from_get_model: float, optional
    :param shared_http_client: Pre-configured HTTP client with proxy/timeout settings
    :type shared_http_client: httpx.AsyncClient, optional
    :return: Configured OpenAI model instance ready for structured generation
    :rtype: OpenAIModel
    
    .. note::
       When a shared HTTP client is provided, timeout configuration is managed
       by the client rather than the OpenAI client constructor.
    
    .. seealso::
       :func:`get_model` : Main factory function that calls this helper
       :class:`pydantic_ai.models.openai.OpenAIModel` : PydanticAI OpenAI model wrapper
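
    Example:
        Illustrative construction with placeholder credentials (building the OpenAI
        client performs no network call, so this runs offline; the model ID and key
        below are placeholders)::

            >>> model = _create_openai_compatible_model(
            ...     model_id="gpt-4o",
            ...     api_key="sk-placeholder",
            ...     base_url=None,
            ...     timeout_arg_from_get_model=30.0,
            ... )
            >>> model.model_id
            'gpt-4o'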
    """
    
    openai_client_instance: openai.AsyncOpenAI
    if shared_http_client:
        client_args = {
            "api_key": api_key,
            "http_client": shared_http_client
        }
        if base_url: # Pass base_url if provided
            client_args["base_url"] = base_url
        openai_client_instance = openai.AsyncOpenAI(**client_args)
    else:
        # No shared client: configure the timeout directly on the OpenAI client.
        effective_timeout_for_openai = timeout_arg_from_get_model if timeout_arg_from_get_model is not None else 60.0
        client_args = {
            "api_key": api_key,
            "timeout": effective_timeout_for_openai
        }
        if base_url: # Pass base_url if provided
            client_args["base_url"] = base_url
        openai_client_instance = openai.AsyncOpenAI(**client_args)

    model = OpenAIModel(
        model_name=model_id,
        provider=OpenAIProvider(openai_client=openai_client_instance),
    )
    # Store the original model_id for later reference
    model.model_id = model_id
    return model


logger = logging.getLogger(__name__)


def _validate_proxy_url(proxy_url: str) -> bool:
    """Validate HTTP proxy URL format and accessibility.
    
    Performs basic validation of proxy URL format to ensure it follows
    standard HTTP/HTTPS proxy URL patterns. This helps catch common
    configuration errors early and provides clear feedback.
    
    :param proxy_url: Proxy URL to validate
    :type proxy_url: str
    :return: True if proxy URL appears valid, False otherwise
    :rtype: bool
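
    Example (the proxy host below is a placeholder)::

        >>> _validate_proxy_url("http://proxy.example.com:8080")
        True
        >>> _validate_proxy_url("proxy.example.com:8080")
        False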
    """
    if not proxy_url:
        return False
    
    try:
        parsed = urlparse(proxy_url)
        # Check for valid scheme and netloc (host:port)
        if parsed.scheme not in ('http', 'https'):
            return False
        if not parsed.netloc:
            return False
        return True
    except Exception:
        return False


def _get_ollama_fallback_urls(base_url: str) -> list[str]:
    """Generate fallback URLs for Ollama based on the current base URL.
    
    This helper function generates appropriate fallback URLs to handle
    common development scenarios where the execution context (container vs local)
    doesn't match the configured Ollama URL.
    
    :param base_url: Current configured Ollama base URL
    :type base_url: str
    :return: List of fallback URLs to try in order
    :rtype: list[str]
    
    .. note::
       Fallback URLs are generated based on common patterns:
       - host.containers.internal -> localhost (container to local)
       - localhost -> host.containers.internal (local to container)
       - Generic fallbacks for other scenarios
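
    Example (port 8080 is illustrative, chosen to make the substitution visible)::

        >>> _get_ollama_fallback_urls("http://host.containers.internal:8080")
        ['http://localhost:8080', 'http://localhost:11434']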
    """
    fallback_urls = []
    
    if "host.containers.internal" in base_url:
        # Running in container but Ollama might be on localhost
        fallback_urls = [
            base_url.replace("host.containers.internal", "localhost"),
            "http://localhost:11434"
        ]
    elif "localhost" in base_url:
        # Running locally but Ollama might be in container context
        fallback_urls = [
            base_url.replace("localhost", "host.containers.internal"),
            "http://host.containers.internal:11434"
        ]
    else:
        # Generic fallbacks for other scenarios
        fallback_urls = [
            "http://localhost:11434",
            "http://host.containers.internal:11434"
        ]
    
    return fallback_urls


def _test_ollama_connection(base_url: str) -> bool:
    """Test if Ollama is accessible at the given URL.
    
    Performs a simple health check by attempting to connect to Ollama
    and calling the list models endpoint.
    
    :param base_url: Ollama base URL to test
    :type base_url: str
    :return: True if connection successful, False otherwise
    :rtype: bool
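
    Example (the result depends on whether an Ollama server is actually
    listening at the given URL)::

        >>> _test_ollama_connection("http://localhost:11434")
        True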
    """
    try:
        # Test with a simple synchronous request to avoid async complications
        import requests
        # Convert to OpenAI-compatible endpoint for testing
        test_url = base_url.rstrip('/') + '/v1/models'
        response = requests.get(test_url, timeout=2)
        return response.status_code == 200
    except Exception:
        return False


def _create_ollama_model_with_fallback(
    model_id: str,
    base_url: str,
    provider_config: dict,
    timeout: Optional[float],
    async_http_client: Optional[httpx.AsyncClient]
) -> OpenAIModel:
    """Create Ollama model with graceful fallback for development workflows.
    
    This function attempts to connect to Ollama at the configured URL first,
    then tries common fallback URLs if the initial connection fails. This
    handles the common development scenario where execution context (container
    vs local) doesn't match the configured Ollama endpoint.
    
    :param model_id: Ollama model identifier
    :type model_id: str
    :param base_url: Primary Ollama base URL from configuration
    :type base_url: str
    :param provider_config: Provider configuration dictionary
    :type provider_config: dict
    :param timeout: Request timeout in seconds
    :type timeout: float, optional
    :param async_http_client: Pre-configured HTTP client
    :type async_http_client: httpx.AsyncClient, optional
    :return: Configured OpenAI-compatible model for Ollama
    :rtype: OpenAIModel
    :raises ValueError: If no working Ollama endpoint is found
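
    Example:
        Illustrative call with a placeholder ``provider_config``; this only succeeds
        when an Ollama server is reachable at one of the candidate URLs, otherwise
        a ValueError is raised::

            >>> model = _create_ollama_model_with_fallback(
            ...     model_id="llama3.1:8b",
            ...     base_url="http://localhost:11434",
            ...     provider_config={"api_key": "ollama"},
            ...     timeout=30.0,
            ...     async_http_client=None,
            ... )
            >>> model.model_id
            'llama3.1:8b'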
    """
    effective_base_url = base_url
    if not base_url.endswith('/v1'):
        effective_base_url = base_url.rstrip('/') + '/v1'
    
    used_fallback = False
    
    # Test primary URL first
    if _test_ollama_connection(base_url):
        logger.debug(f"Successfully connected to Ollama at {base_url}")
    else:
        logger.debug(f"Failed to connect to Ollama at {base_url}")
        
        # Try fallback URLs
        fallback_urls = _get_ollama_fallback_urls(base_url)
        working_url = None
        
        for fallback_url in fallback_urls:
            logger.debug(f"Attempting fallback connection to Ollama at {fallback_url}")
            if _test_ollama_connection(fallback_url):
                working_url = fallback_url
                used_fallback = True
                logger.warning(
                    f"⚠️  Ollama connection fallback: configured URL '{base_url}' failed, "
                    f"using fallback '{fallback_url}'. Consider updating your configuration "
                    f"for your current execution environment."
                )
                break
        
        if working_url:
            effective_base_url = working_url
            if not working_url.endswith('/v1'):
                effective_base_url = working_url.rstrip('/') + '/v1'
        else:
            # All connection attempts failed
            raise ValueError(
                f"Failed to connect to Ollama at configured URL '{base_url}' "
                f"and all fallback URLs {fallback_urls}. Please ensure Ollama is running "
                f"and accessible, or update your configuration."
            )
    
    return _create_openai_compatible_model(
        model_id=model_id,
        api_key=provider_config.get("api_key"),
        base_url=effective_base_url,
        timeout_arg_from_get_model=timeout,
        shared_http_client=async_http_client
    )


def get_model(
    provider: Optional[str] = None,
    model_config: Optional[dict] = None,
    model_id: Optional[str] = None,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    timeout: Optional[float] = None,
    max_tokens: int = 100000,
) -> Union[OpenAIModel, AnthropicModel, GeminiModel]:
    """Create a configured LLM model instance for structured generation with PydanticAI.

    This factory function creates and configures LLM model instances optimized for
    structured generation workflows using PydanticAI agents. It handles provider-specific
    initialization, credential validation, HTTP client configuration, and proxy setup
    automatically based on environment variables and configuration files.

    The function supports flexible configuration through multiple approaches:

    - Direct parameter specification for programmatic use
    - Model configuration dictionaries from YAML files
    - Automatic credential loading from the configuration system
    - Environment-based HTTP proxy detection and configuration

    Provider-specific behavior:

    - **Anthropic**: Requires API key and model ID, supports HTTP proxy
    - **Google**: Requires API key and model ID, supports HTTP proxy
    - **OpenAI**: Requires API key and model ID, supports HTTP proxy and custom base URLs
    - **Ollama**: Requires model ID and base URL, no API key needed, no proxy support
    - **CBORG**: Requires API key, model ID, and base URL, supports HTTP proxy

    :param provider: AI provider name ('anthropic', 'google', 'openai', 'ollama', 'cborg')
    :type provider: str, optional
    :param model_config: Configuration dictionary with provider, model_id, and other settings
    :type model_config: dict, optional
    :param model_id: Specific model identifier recognized by the provider
    :type model_id: str, optional
    :param api_key: API authentication key, auto-loaded from config if not provided
    :type api_key: str, optional
    :param base_url: Custom API endpoint URL, required for Ollama and CBORG
    :type base_url: str, optional
    :param timeout: Request timeout in seconds, defaults to provider configuration
    :type timeout: float, optional
    :param max_tokens: Maximum tokens for generation, defaults to 100000
    :type max_tokens: int
    :raises ValueError: If required provider, model_id, api_key, or base_url are missing
    :raises ValueError: If provider is not supported
    :return: Configured model instance ready for PydanticAI agent integration
    :rtype: Union[OpenAIModel, AnthropicModel, GeminiModel]

    .. note::
       HTTP proxy configuration is automatically detected from the HTTP_PROXY environment
       variable for supported providers. Timeout and connection pooling are managed through
       shared HTTP clients when proxies are enabled.

    .. warning::
       API keys and base URLs are validated before model creation. Ensure proper
       configuration is available through the config system or direct parameter
       specification.

    Examples:
        Basic model creation with direct parameters::

            >>> from framework.models import get_model
            >>> model = get_model(
            ...     provider="anthropic",
            ...     model_id="claude-3-sonnet-20240229",
            ...     api_key="your-api-key"
            ... )
            >>> # Use with PydanticAI Agent
            >>> agent = Agent(model=model, output_type=YourModel)

        Using configuration dictionary from YAML::

            >>> model_config = {
            ...     "provider": "cborg",
            ...     "model_id": "anthropic/claude-sonnet",
            ...     "max_tokens": 4096,
            ...     "timeout": 30.0
            ... }
            >>> model = get_model(model_config=model_config)

        Ollama local model setup::

            >>> model = get_model(
            ...     provider="ollama",
            ...     model_id="llama3.1:8b",
            ...     base_url="http://localhost:11434"
            ... )

    .. seealso::
       :func:`~completion.get_chat_completion` : Direct chat completion without structured output
       :func:`configs.config.get_provider_config` : Provider configuration loading
       :class:`pydantic_ai.Agent` : PydanticAI agent that uses these models
       :doc:`/developer-guides/01_understanding-the-framework/02_convention-over-configuration` : Complete model setup guide
    """
    if model_config:
        provider = model_config.get("provider", provider)
        model_id = model_config.get("model_id", model_id)
        max_tokens = model_config.get("max_tokens", max_tokens)
        timeout = model_config.get("timeout", timeout)

    if not provider:
        raise ValueError("Provider must be specified either directly or via model_config")

    provider_config = get_provider_config(provider)
    api_key = provider_config.get("api_key", api_key)
    base_url = provider_config.get("base_url", base_url)
    timeout = provider_config.get("timeout", timeout)

    # Define provider requirements
    provider_requirements = {
        "google": {"model_id": True, "api_key": True, "base_url": False, "use_proxy": True},
        "anthropic": {"model_id": True, "api_key": True, "base_url": False, "use_proxy": True},
        "openai": {"model_id": True, "api_key": True, "base_url": False, "use_proxy": True},
        "ollama": {"model_id": True, "api_key": False, "base_url": True, "use_proxy": False},
        "cborg": {"model_id": True, "api_key": True, "base_url": True, "use_proxy": True},
    }

    if provider not in provider_requirements:
        raise ValueError(
            f"Invalid provider: {provider}. Must be 'anthropic', 'cborg', 'google', 'ollama', or 'openai'."
        )

    requirements = provider_requirements[provider]

    # Common validation
    if requirements["model_id"] and not model_id:
        raise ValueError(f"Model ID for {provider} not provided.")
    if requirements["api_key"] and not api_key:
        raise ValueError(f"No API key provided for {provider}.")
    if requirements["base_url"] and not base_url:
        raise ValueError(f"No base URL provided for {provider}.")

    async_http_client: Optional[httpx.AsyncClient] = None

    # HTTP proxy is machine dependent and set up through environment variable
    proxy_url = os.getenv("HTTP_PROXY")
    should_use_proxy = False
    if requirements["use_proxy"] and proxy_url:
        if _validate_proxy_url(proxy_url):
            should_use_proxy = True
        else:
            logger.warning(f"Invalid HTTP_PROXY URL format '{proxy_url}', ignoring proxy configuration")

    # Create a custom client if a proxy is set (and should be used) or a specific timeout is requested
    if should_use_proxy or timeout is not None:
        client_params = {}
        if should_use_proxy:
            client_params["proxy"] = proxy_url
        if timeout is not None:
            client_params["timeout"] = timeout
        async_http_client = httpx.AsyncClient(**client_params)

    # Provider-specific implementation (validation already done above)
    if provider == "google":
        google_provider = GoogleGLAProvider(
            api_key=api_key,
            http_client=async_http_client
        )
        return GeminiModel(model_name=model_id, provider=google_provider)

    elif provider == "anthropic":
        anthropic_provider = AnthropicProvider(
            api_key=api_key,
            http_client=async_http_client
        )
        return AnthropicModel(
            model_name=model_id,
            provider=anthropic_provider,
        )

    elif provider == "openai":
        return _create_openai_compatible_model(
            model_id=model_id,
            api_key=api_key,
            base_url=base_url,
            timeout_arg_from_get_model=timeout,
            shared_http_client=async_http_client
        )

    elif provider == "ollama":
        return _create_ollama_model_with_fallback(
            model_id=model_id,
            base_url=base_url,
            provider_config=provider_config,
            timeout=timeout,
            async_http_client=async_http_client
        )

    elif provider == "cborg":
        return _create_openai_compatible_model(
            model_id=model_id,
            api_key=api_key,
            base_url=base_url,
            timeout_arg_from_get_model=timeout,
            shared_http_client=async_http_client
        )