"""Dynamic MCP tool loader — connects to user-configured MCP servers and returns their tools."""

import functools
import json
import logging
import time
import urllib.parse
import urllib.request

import boto3
from mcp.client.streamable_http import streamablehttp_client
from strands.tools.mcp.mcp_client import MCPClient

logger = logging.getLogger(__name__)

# Token cache: {f"{actor_id}:{conn_name}": {"token": str, "expires_at": float}}
# ``expires_at`` is the absolute time.time() at which the token endpoint said the
# token expires; each reader applies its own refresh margin when checking it.
_token_cache: dict = {}

# Seconds before ``expires_at`` at which a cached token is treated as stale.
_OAUTH_REFRESH_MARGIN = 30
_M2M_REFRESH_MARGIN = 60
# Token lifetime assumed when the endpoint omits (or garbles) ``expires_in``.
_DEFAULT_EXPIRES_IN = 3600.0
# Timeout (seconds) for the token endpoint HTTP request.
_TOKEN_REQUEST_TIMEOUT = 10
# Token URLs come from user configuration; urlopen would otherwise also honor
# schemes such as file://, so only network schemes are accepted.
_ALLOWED_TOKEN_URL_SCHEMES = ('https', 'http')


def _get_ssm_value(param_name: str, region_name: str = 'us-east-1') -> str:
    """Return the decrypted value of SSM parameter *param_name*.

    Args:
        param_name: Name of the SSM parameter to read.
        region_name: AWS region holding the parameter (defaults to us-east-1,
            the region this module has always used).
    """
    ssm = boto3.client('ssm', region_name=region_name)
    return ssm.get_parameter(Name=param_name, WithDecryption=True)['Parameter']['Value']


def _cache_key(actor_id: str, conn_name: str) -> str:
    """Build the ``_token_cache`` key for one actor/connection pair."""
    return f"{actor_id}:{conn_name}"


def _get_cached_token(cache_key: str, margin: float) -> str | None:
    """Return the cached token if it stays valid for more than *margin* seconds, else None."""
    entry = _token_cache.get(cache_key)
    if entry and entry['expires_at'] > time.time() + margin:
        return entry['token']
    return None


def _store_token(cache_key: str, token: str, expires_in: float) -> None:
    """Cache *token* with its absolute expiry time."""
    _token_cache[cache_key] = {'token': token, 'expires_at': time.time() + expires_in}


def _normalize_scope(scope) -> str:
    """Render *scope* as the space-delimited string OAuth expects (RFC 6749 §3.3).

    Accepts a string or a list/tuple of scope names; empty/None yields ''.
    """
    if not scope:
        return ''
    if isinstance(scope, (list, tuple)):
        # urlencode would otherwise send the Python repr, e.g. "['a', 'b']".
        return ' '.join(str(s) for s in scope if s)
    return str(scope)


def _parse_expires_in(raw) -> float:
    """Convert a token response's ``expires_in`` to seconds, defaulting when absent/invalid."""
    if raw is None:
        return _DEFAULT_EXPIRES_IN
    try:
        # Some providers send the value as a string.
        return float(raw)
    except (TypeError, ValueError):
        logger.warning('[mcp_loader] Ignoring invalid expires_in %r; assuming %ss',
                       raw, _DEFAULT_EXPIRES_IN)
        return _DEFAULT_EXPIRES_IN


def _request_client_credentials_token(token_url: str, client_id: str,
                                      client_secret: str, scope) -> tuple[str, float]:
    """POST an OAuth client_credentials grant to *token_url*.

    Returns:
        ``(access_token, expires_in_seconds)``.

    Raises:
        ValueError: if *token_url* is not an http(s) URL.
        urllib.error.URLError / HTTPError: on transport or non-2xx responses.
        KeyError: if the response JSON has no ``access_token``.
    """
    scheme = urllib.parse.urlsplit(token_url).scheme.lower()
    if scheme not in _ALLOWED_TOKEN_URL_SCHEMES:
        raise ValueError(f'Unsupported token URL scheme {scheme!r}; expected http or https')

    form = {
        'grant_type': 'client_credentials',
        'client_id': client_id,
        'client_secret': client_secret,
    }
    scope_str = _normalize_scope(scope)
    if scope_str:
        # scope is optional; omit it rather than sending an empty value.
        form['scope'] = scope_str

    req = urllib.request.Request(
        token_url,
        data=urllib.parse.urlencode(form).encode(),
        headers={'Content-Type': 'application/x-www-form-urlencoded'},
    )
    with urllib.request.urlopen(req, timeout=_TOKEN_REQUEST_TIMEOUT) as resp:
        body = json.loads(resp.read())
    return body['access_token'], _parse_expires_in(body.get('expires_in'))


def _get_oauth_token(conn: dict, actor_id: str) -> str:
    """Fetch OAuth token via client_credentials grant, with caching.

    The client secret is read from the SSM parameter named by
    ``conn['client_secret_ssm']``; the token endpoint is ``conn['cognito_token_url']``.
    A cached token is reused until it is within 30 seconds of expiry.
    """
    cache_key = _cache_key(actor_id, conn['name'])
    cached = _get_cached_token(cache_key, _OAUTH_REFRESH_MARGIN)
    if cached is not None:
        return cached

    client_secret = _get_ssm_value(conn['client_secret_ssm'])
    token, expires_in = _request_client_credentials_token(
        conn['cognito_token_url'], conn['client_id'], client_secret, conn.get('scope', ''))
    _store_token(cache_key, token, expires_in)
    return token


def _get_m2m_token(conn: dict, actor_id: str) -> str:
    """Fetch OAuth token for oauth2_m2m (secret stored directly in record).

    The token endpoint is ``conn['token_url']``; scopes come from ``conn['scopes']``
    (string or list) falling back to ``conn['scope']``. A cached token is reused
    until it is within 60 seconds of expiry.
    """
    cache_key = _cache_key(actor_id, conn['name'])
    cached = _get_cached_token(cache_key, _M2M_REFRESH_MARGIN)
    if cached is not None:
        return cached

    token, expires_in = _request_client_credentials_token(
        conn['token_url'], conn['client_id'], conn['client_secret'],
        conn.get('scopes', conn.get('scope', '')))
    _store_token(cache_key, token, expires_in)
    return token


def _resolve_auth_headers(conn: dict, actor_id: str) -> dict:
    """Resolve auth headers for a connection.

    Returns ``{'Authorization': 'Bearer <token>'}`` for the supported auth types
    (``oauth_client_credentials``, ``oauth2_m2m``, ``bearer``) and ``{}`` otherwise.
    """
    auth_type = conn.get('auth_type', 'none')
    if auth_type == 'oauth_client_credentials':
        token = _get_oauth_token(conn, actor_id)
    elif auth_type == 'oauth2_m2m':
        token = _get_m2m_token(conn, actor_id)
    elif auth_type == 'bearer':
        token = _get_ssm_value(conn['token_ssm'])
    else:
        if auth_type not in (None, 'none'):
            # Probably a misconfiguration; make the unauthenticated fallback visible.
            logger.warning('[mcp_loader] Unknown auth_type %r for MCP server "%s"; '
                           'connecting without auth', auth_type, conn.get('name', 'unknown'))
        return {}
    return {'Authorization': f'Bearer {token}'}


def invalidate_token(conn_name: str, actor_id: str):
    """Invalidate cached token for a connection (call on auth failure)."""
    _token_cache.pop(_cache_key(actor_id, conn_name), None)


def load_mcp_tools(mcp_connections: list, actor_id: str) -> tuple[list, list]:
    """Start an MCPClient for each enabled MCP connection.

    Connections whose ``enabled`` is falsy are skipped. A connection whose auth
    resolution or startup fails is logged and skipped, so one bad server does
    not prevent the others from loading.

    Args:
        mcp_connections: Connection records (``name``, ``url``, ``auth_type``,
            auth fields, optional ``enabled``). ``None`` is treated as empty.
        actor_id: Identity used to scope cached OAuth tokens.

    Returns:
        Tuple of (list of MCPClient instances to pass to Agent, list of same
        clients to close later) — both elements are the same list object.
    """
    clients = []
    for conn in mcp_connections or []:
        if not conn.get('enabled', True):
            continue
        name = conn.get('name', 'unknown')
        try:
            headers = _resolve_auth_headers(conn, actor_id)
            # partial binds url/headers now, avoiding the late-binding closure trap.
            client = MCPClient(functools.partial(streamablehttp_client, conn['url'], headers=headers))
            client.start()
        except Exception as e:  # boundary: one failing server must not abort the rest
            logger.error('[mcp_loader] Failed to connect to MCP server "%s": %s', name, e)
            continue
        clients.append(client)
        logger.info('[mcp_loader] Connected to MCP server: %s', name)
    return clients, clients


def close_mcp_clients(clients: list):
    """Close all MCP clients, logging (not raising) any per-client failure."""
    for client in clients or []:
        try:
            # Use the context-manager exit hook instead of a bare ``stop()``:
            # NOTE(review): strands' MCPClient.stop is declared as
            # stop(exc_type, exc_val, exc_tb), so ``client.stop()`` raises
            # TypeError and the client is never shut down. __exit__ performs the
            # same shutdown independent of stop's signature — confirm against the
            # pinned strands version.
            client.__exit__(None, None, None)
        except Exception as e:
            logger.error('[mcp_loader] Error closing MCP client: %s', e)