# Source code for framework.infrastructure.classification_node

"""
ALS Expert Agent - Classification Node

Task classification and capability selection with sophisticated analysis.
Combines LangGraph infrastructure with core classification logic.

Analyzes user queries to determine required capabilities and data dependencies.
Convention-based LangGraph-native implementation with built-in error handling and retry policies.
"""

from __future__ import annotations
import asyncio

from typing import Optional, TYPE_CHECKING, Dict, Any, List


from framework.base.decorators import infrastructure_node
from framework.state import AgentState
from framework.state.state import create_status_update
from framework.registry import get_registry
from framework.base import BaseCapability, ClassifierExample, CapabilityMatch
from framework.models import get_chat_completion
from framework.prompts.loader import get_framework_prompts
from configs.config import get_full_configuration, get_model_config
from configs.logger import get_logger
from configs.streaming import get_streamer
from framework.base.errors import ErrorClassification, ErrorSeverity
from framework.base.nodes import BaseInfrastructureNode


# Module-level logger shared by ClassificationNode and the classification helpers below
logger = get_logger("framework", "classifier")


@infrastructure_node
class ClassificationNode(BaseInfrastructureNode):
    """Infrastructure node that decides which capabilities a task needs.

    The node reads the current task from the agent state, asks
    ``select_capabilities`` which registered capabilities apply, and returns
    state updates that store the selection, reset the execution plan and clear
    any pending reclassification request. It also supplies its own error
    classification and retry policy.
    """

    # Loaded through registry configuration
    name = "classifier"
    description = "Task Classification and Capability Selection"

    @staticmethod
    def classify_error(exc: Exception, context: Dict[str, Any]) -> ErrorClassification:
        """Map an exception raised during classification to a severity.

        Timeouts and network failures are reported as retriable; validation
        errors, import-style errors and anything unrecognised are critical.

        :param exc: Exception that occurred
        :param context: Error context information (not inspected here)
        :return: Classification with severity and retry guidance
        """
        # Exception families treated as "missing dependency / bad path".
        import_failures = (ImportError, ModuleNotFoundError, NameError)

        # Timeouts are recognised by class name so that client-specific
        # timeout exceptions are matched without importing their types.
        if 'timeout' in exc.__class__.__name__.lower():
            return ErrorClassification(
                severity=ErrorSeverity.RETRIABLE,
                user_message="Classification service temporarily unavailable, retrying...",
                metadata={"technical_details": f"LLM timeout: {str(exc)}"}
            )

        if isinstance(exc, (ConnectionError, TimeoutError)):
            return ErrorClassification(
                severity=ErrorSeverity.RETRIABLE,
                user_message="Network connectivity issues during classification, retrying...",
                metadata={"technical_details": f"Network error: {str(exc)}"}
            )

        # Validation problems are data/logic issues; retrying will not help.
        if isinstance(exc, (ValueError, TypeError)):
            return ErrorClassification(
                severity=ErrorSeverity.CRITICAL,
                user_message="Task classification configuration error",
                metadata={
                    "technical_details": f"Validation error: {str(exc)}",
                    "safety_abort_reason": "Classification system misconfiguration detected"
                }
            )

        # Import-style failures count whether they are the exception itself
        # or its direct cause (``raise X from Y``).
        chained = getattr(exc, '__cause__', None)
        caused_by_import = bool(chained) and isinstance(chained, import_failures)
        if isinstance(exc, import_failures) or caused_by_import:
            return ErrorClassification(
                severity=ErrorSeverity.CRITICAL,
                user_message="Task classification dependencies not available",
                metadata={
                    "technical_details": f"Import error: {str(exc)}",
                    "safety_abort_reason": "Required classification dependencies missing"
                }
            )

        # Fail safe: only explicitly recognised errors are RETRIABLE, so
        # everything else is reported as CRITICAL.
        return ErrorClassification(
            severity=ErrorSeverity.CRITICAL,
            user_message=f"Unknown classification error: {str(exc)}",
            metadata={
                "technical_details": f"Error type: {type(exc).__name__}, Details: {str(exc)}",
                "safety_abort_reason": "Unhandled classification system error"
            }
        )

    @staticmethod
    def get_retry_policy() -> Dict[str, Any]:
        """Return the retry settings used for this node.

        Classification issues LLM calls, which can fail transiently (for
        example through the timeouts and network errors that
        ``classify_error`` marks as retriable), so the policy allows several
        attempts with a moderate, growing delay.

        :return: Dictionary with ``max_attempts``, ``delay_seconds`` and
            ``backoff_factor``
        """
        return dict(
            max_attempts=4,       # More attempts for LLM classification
            delay_seconds=1.0,    # Moderate initial delay between attempts
            backoff_factor=1.8,   # Moderate backoff to handle rate limiting
        )

    @staticmethod
    async def execute(
        state: AgentState,
        **kwargs
    ) -> Dict[str, Any]:
        """Classify the current task and return the resulting state updates.

        Covers both the initial classification and reclassification after a
        previous failure. If the state carries no current task, the returned
        update only requests reclassification.

        :param state: Current agent state
        :type state: AgentState
        :param kwargs: Additional LangGraph parameters (unused here)
        :return: Dictionary of state updates for LangGraph
        :rtype: Dict[str, Any]
        """
        current_task = state.get("task_current_task")

        # Without a task there is nothing to classify.
        if not current_task:
            logger.error("No current task found in state")
            return {
                "control_needs_reclassification": True,
                "control_reclassification_reason": "No current task found"
            }

        streamer = get_streamer("framework", "classifier", state)

        # A stored failure reason means this run is a reclassification;
        # it is None on the initial pass.
        previous_failure = state.get('control_reclassification_reason')
        reclassification_count = state.get('control_reclassification_count', 0)

        if not previous_failure:
            streamer.status("Analyzing task requirements...")
            logger.info("Analyzing task requirements...")
        else:
            streamer.status(f"Reclassifying task (attempt {reclassification_count + 1})...")
            logger.info(f"Reclassifying task (attempt {reclassification_count + 1})...")
            logger.warning(f"Previous failure reason: {previous_failure}")

        logger.info(f"Classifying task: {current_task}")

        # Candidate capabilities come from the capability registry.
        available_capabilities = get_registry().get_all_capabilities()
        logger.debug(f"Available capabilities: {len(available_capabilities)}")

        # The selection logic itself lives in select_capabilities.
        active_capabilities = await select_capabilities(
            task=current_task,
            available_capabilities=available_capabilities,
            state=state,
            logger=logger,
            previous_failure=previous_failure  # Failure context for reclassification
        )

        logger.key_info(f"Classification completed with {len(active_capabilities)} active capabilities")
        logger.debug(f"Active capabilities: {active_capabilities}")
        streamer.status("Task classification complete")

        # The counter is incremented on every classification run.
        next_count = reclassification_count + 1

        status_event = create_status_update(
            message=f"Classification completed with {len(active_capabilities)} capabilities",
            progress=1.0,
            complete=True,
            node="classifier",
            capabilities_selected=len(active_capabilities),
            capability_names=active_capabilities,  # Already capability names
            reclassification=bool(previous_failure),
            reclassification_count=next_count
        )

        logger.info("Classification completed")

        # Single merged update: planning fields, control-flow fields, then the
        # status event (later entries win on key collisions).
        return {
            "planning_active_capabilities": active_capabilities,
            "planning_execution_plan": None,
            "planning_current_step_index": 0,
            "control_reclassification_count": next_count,
            "control_needs_reclassification": False,
            "control_reclassification_reason": None,
            **status_event,
        }
# ====================================================
# Classification helper functions
# ====================================================
async def select_capabilities(
    task: str,
    available_capabilities: List[BaseCapability],
    state: AgentState,
    logger,
    previous_failure: Optional[str] = None
) -> List[str]:
    """Select the capabilities needed for the task by using classification.

    Capabilities that the registry lists as always active are included
    unconditionally. Every other capability is classified individually via
    ``_classify_capability``; those classifications are awaited concurrently
    with ``asyncio.gather`` so the total wait is not the sum of all calls.
    The returned list keeps always-active names first, followed by the
    selected capabilities in the order they appear in
    ``available_capabilities``.

    If a classification raises, the exception propagates to the caller;
    ``asyncio.gather`` does not cancel the other classifications already
    in flight.

    :param task: Task description for analysis
    :type task: str
    :param available_capabilities: Available capabilities to choose from
    :type available_capabilities: List[BaseCapability]
    :param state: Current agent state
    :type state: AgentState
    :param logger: Logger instance
    :param previous_failure: Reason the previous classification was rejected,
        or None on the initial classification
    :type previous_failure: Optional[str]
    :return: List of capability names needed for the task
    :rtype: List[str]
    """
    # Get registry to access always-active capability names
    registry = get_registry()
    always_active_names = registry.get_always_active_capability_names()

    # Step 1: always-active capabilities are included without classification
    active_capabilities: List[str] = [
        capability.name
        for capability in available_capabilities
        if capability.name in always_active_names
    ]

    # Step 2: everything else must be classified against the task
    remaining_capabilities = [
        capability
        for capability in available_capabilities
        if capability.name not in always_active_names
    ]

    if remaining_capabilities:
        # gather() returns results in argument order, so each decision lines
        # up with the capability at the same index.
        decisions = await asyncio.gather(*(
            _classify_capability(capability, task, state, logger, previous_failure)
            for capability in remaining_capabilities
        ))
        active_capabilities.extend(
            capability.name
            for capability, is_required in zip(remaining_capabilities, decisions)
            if is_required
        )

    logger.info(f"{len(active_capabilities)} capabilities required: {active_capabilities}")
    return active_capabilities
async def _classify_capability(capability: BaseCapability, task: str, state: AgentState, logger, previous_failure: Optional[str] = None) -> bool:
    """Decide whether a single capability is needed for the task.

    Builds a classification prompt from the capability's classifier guide and
    asks the configured classifier model for a ``CapabilityMatch``.

    :param capability: The capability to analyze
    :type capability: BaseCapability
    :param task: Task description for analysis
    :type task: str
    :param state: Current agent state (not read by this function)
    :type state: AgentState
    :param logger: Logger instance
    :param previous_failure: Failure context forwarded to the prompt builder
    :type previous_failure: Optional[str]
    :return: True if capability is required, False otherwise
    :rtype: bool
    :raises Exception: If loading the classifier guide fails with anything
        other than an import-style error
    """
    # Loading the guide can itself fail, so the lookup is guarded.
    try:
        classifier = capability.classifier_guide
        if not classifier:
            logger.warning(f"No classifier found for capability '{capability.name}' - skipping")
            return False
    except Exception as e:
        logger.error(f"Error loading classifier for capability '{capability.name}': {e}")
        # Anything that is not an import-style failure is re-raised with the
        # capability name attached for better error reporting.
        if not isinstance(e, (ImportError, ModuleNotFoundError, NameError)):
            raise Exception(f"Capability '{capability.name}' classifier failed: {e}") from e
        # Import-style failures only disable this one capability.
        logger.warning(f"Skipping capability '{capability.name}' due to import error: {e}")
        return False

    guide_instructions = classifier.instructions
    formatted_examples = ClassifierExample.format_examples_for_prompt(classifier.examples)

    # Assemble the classification prompt for this capability.
    builder = get_framework_prompts().get_classification_prompt_builder()
    system_prompt = builder.get_system_instructions(
        capability_instructions=guide_instructions,
        classifier_examples=formatted_examples,
        context=None,
        previous_failure=previous_failure
    )
    message = f"{system_prompt}\n\nUser request:\n{task}"
    logger.debug(f"\n\nTask Analyzer System Prompt for capability '{capability.name}':\n{message}\n\n")

    try:
        # get_chat_completion is called in a worker thread so that the event
        # loop is not blocked while waiting for the model.
        response_data = await asyncio.to_thread(
            get_chat_completion,
            model_config=get_model_config("framework", "classifier"),
            message=message,
            output_model=CapabilityMatch,
        )

        if not isinstance(response_data, CapabilityMatch):
            # An unexpected response type counts as "no match".
            logger.error(f"Classification call for '{capability.name}' did not return a CapabilityMatch. Got: {type(response_data)}")
            single_output = CapabilityMatch(is_match=False)
        else:
            single_output = response_data

        logger.info(f" >>> Capability '{capability.name}' >>> {single_output.is_match}")
        return single_output.is_match
    except Exception as e:
        # Best effort: a failed classification call leaves the capability unselected.
        logger.error(f"Error in capability classification for '{capability.name}': {e}")
        return False