Files
daniel 4ca5fee2c0 refactor: move factcloud from hardcoded SSM to per-user DynamoDB oauth2_m2m connection
- Add oauth2_m2m auth type to mcp_loader.py (client_secret in record, not SSM)
- Remove _get_factcloud_token(), FACTCLOUD_* config, factcloud_clients from main.py
- Seed Daniel's factcloud connection into enrolled_services.mcp_connections
- factcloud now loaded dynamically via mcp_loader at session start
2026-05-16 09:49:28 -05:00

123 lines
4.5 KiB
Python

"""Dynamic MCP tool loader — connects to user-configured MCP servers and returns their tools."""
import time
import logging
import urllib.request
import urllib.parse
import json
import boto3
from mcp.client.streamable_http import streamablehttp_client
from strands.tools.mcp.mcp_client import MCPClient
logger = logging.getLogger(__name__)
# Process-local OAuth access-token cache shared by every connection loaded in
# this process.
#   key:   f"{actor_id}:{conn_name}"
#   value: {"token": str, "expires_at": float}
# `expires_at` is an absolute time.time() timestamp that is checked before a
# cached token is reused.
# Within this module, an entry is only overwritten when a token is refetched
# or removed by invalidate_token(). Nothing else evicts entries.
_token_cache: dict[str, dict] = {}
# boto3 clients are relatively costly to construct and safe to reuse, so keep
# one SSM client per region instead of building a new one on every lookup.
_ssm_clients: dict = {}


def _get_ssm_value(param_name: str, *, region_name: str = 'us-east-1') -> str:
    """Return the decrypted value of an SSM Parameter Store parameter.

    Args:
        param_name: Full name of the parameter to read.
        region_name: AWS region that hosts the parameter. Defaults to
            ``us-east-1``, the region this function previously hard-coded.

    Returns:
        The parameter's value, read with ``WithDecryption=True`` so
        SecureString parameters come back in plaintext.
    """
    ssm = _ssm_clients.get(region_name)
    if ssm is None:
        ssm = boto3.client('ssm', region_name=region_name)
        _ssm_clients[region_name] = ssm
    return ssm.get_parameter(Name=param_name, WithDecryption=True)['Parameter']['Value']
# Treat a cached token as stale this many seconds before its real expiry, so a
# token handed out by the cache cannot lapse before the request that uses it.
_TOKEN_EXPIRY_SKEW_SECONDS = 60


def _scope_param(raw) -> str:
    """Normalize a stored scope value into a space-delimited OAuth scope string.

    Accepts a string (returned stripped) or any iterable of scope strings,
    such as a list or set, which is joined with single spaces. A falsy value
    (missing, ``None``, empty) yields ``''``, meaning "request no scope".
    """
    if not raw:
        return ''
    if isinstance(raw, str):
        return raw.strip()
    return ' '.join(str(scope) for scope in raw)


def _cached_token(cache_key: str) -> "str | None":
    """Return the cached token for *cache_key*, or None if absent or near expiry."""
    entry = _token_cache.get(cache_key)
    if entry and entry['expires_at'] > time.time() + _TOKEN_EXPIRY_SKEW_SECONDS:
        return entry['token']
    return None


def _fetch_client_credentials_token(cache_key: str, token_url: str, client_id: str,
                                    client_secret: str, scope: str) -> str:
    """Run an OAuth client_credentials grant, then cache and return the token.

    The ``scope`` form field is sent only when *scope* is non-empty. The token
    is cached under *cache_key* together with its absolute expiry time. That
    time is ``expires_in`` from the response, or 3600 seconds if absent.

    Raises:
        urllib.error.URLError: network failure, or (as HTTPError) a non-2xx response.
        KeyError: the response JSON has no ``access_token``.
    """
    form = {
        'grant_type': 'client_credentials',
        'client_id': client_id,
        'client_secret': client_secret,
    }
    if scope:
        form['scope'] = scope
    req = urllib.request.Request(
        token_url,
        data=urllib.parse.urlencode(form).encode(),
        headers={'Content-Type': 'application/x-www-form-urlencoded'},
    )
    with urllib.request.urlopen(req, timeout=10) as resp:
        body = json.loads(resp.read())
    token = body['access_token']
    # float() tolerates servers that return expires_in as a string.
    expires_in = float(body.get('expires_in', 3600))
    _token_cache[cache_key] = {'token': token, 'expires_at': time.time() + expires_in}
    return token


def _get_oauth_token(conn: dict, actor_id: str) -> str:
    """Fetch an OAuth token via the client_credentials grant, with caching.

    The request uses these record fields:

    - ``conn['cognito_token_url']``: token endpoint.
    - ``conn['client_id']``: client identifier.
    - ``conn['scope']``: optional scope.
    - ``conn['client_secret_ssm']``: name of the SSM parameter holding the
      client secret. The parameter is read only on a cache miss.
    """
    cache_key = f"{actor_id}:{conn['name']}"
    token = _cached_token(cache_key)
    if token is not None:
        return token
    return _fetch_client_credentials_token(
        cache_key,
        conn['cognito_token_url'],
        conn['client_id'],
        _get_ssm_value(conn['client_secret_ssm']),
        _scope_param(conn.get('scope')),
    )


def _get_m2m_token(conn: dict, actor_id: str) -> str:
    """Fetch an OAuth token for an ``oauth2_m2m`` connection, with caching.

    The client secret is stored directly in the record
    (``conn['client_secret']``). The token endpoint is ``conn['token_url']``.
    Scopes come from ``conn['scopes']`` (a string or an iterable of strings),
    falling back to ``conn['scope']``.
    """
    cache_key = f"{actor_id}:{conn['name']}"
    token = _cached_token(cache_key)
    if token is not None:
        return token
    return _fetch_client_credentials_token(
        cache_key,
        conn['token_url'],
        conn['client_id'],
        conn['client_secret'],
        _scope_param(conn.get('scopes', conn.get('scope'))),
    )
def _resolve_auth_headers(conn: dict, actor_id: str) -> dict:
    """Build the HTTP headers needed to authenticate to an MCP connection.

    Supported ``conn['auth_type']`` values:

    - ``'oauth_client_credentials'``: client-credentials token whose secret
      is read from SSM.
    - ``'oauth2_m2m'``: client-credentials token whose secret is stored in
      the record.
    - ``'bearer'``: static token read from the SSM parameter named by
      ``conn['token_ssm']``.
    - ``'none'``, or a missing/None auth_type: no auth headers.

    Any other value also yields no headers. In that case a warning is logged,
    so a misconfigured record is not silently used unauthenticated.

    Returns:
        ``{'Authorization': 'Bearer <token>'}`` for the token-based types,
        otherwise ``{}``.
    """
    auth_type = conn.get('auth_type') or 'none'
    if auth_type == 'oauth_client_credentials':
        token = _get_oauth_token(conn, actor_id)
    elif auth_type == 'oauth2_m2m':
        token = _get_m2m_token(conn, actor_id)
    elif auth_type == 'bearer':
        token = _get_ssm_value(conn['token_ssm'])
    else:
        if auth_type != 'none':
            logger.warning(
                '[mcp_loader] Unknown auth_type %r for MCP connection "%s"; sending no auth headers',
                auth_type, conn.get('name', 'unknown'),
            )
        return {}
    return {'Authorization': f'Bearer {token}'}
def invalidate_token(conn_name: str, actor_id: str):
    """Forget the cached OAuth token for one (actor, connection) pair.

    Call this after an authentication failure so the next lookup requests a
    fresh token. Removing a key that is not cached is a no-op.
    """
    stale_key = f"{actor_id}:{conn_name}"
    _token_cache.pop(stale_key, None)
def load_mcp_tools(mcp_connections: list, actor_id: str) -> tuple[list, list]:
    """Connect to each enabled MCP server and return (tools_list, clients_to_close).

    A connection whose ``enabled`` field is falsy is skipped; a missing field
    counts as enabled. If a connection fails to authenticate or start, the
    failure is logged and the connection is skipped, so one bad server does
    not block the others. Its cached OAuth token, if any, is also dropped, so
    the next attempt fetches a fresh token instead of reusing a possibly
    revoked one.

    Args:
        mcp_connections: Connection records. Each needs ``url`` plus whatever
            fields its ``auth_type`` requires. ``None`` is treated as empty.
        actor_id: Identifier of the actor whose connections these are. It is
            combined with the connection name to key the token cache.

    Returns:
        Tuple of (list of MCPClient instances to pass to Agent, list of same
        clients to close later). Both elements are the same list object.
    """
    clients = []
    for conn in mcp_connections or ():
        if not conn.get('enabled', True):
            continue
        name = conn.get('name', 'unknown')
        try:
            # Read the URL first so a malformed record fails before any token fetch.
            url = conn['url']
            headers = _resolve_auth_headers(conn, actor_id)
            # Default arguments bind this iteration's url/headers into the factory.
            client = MCPClient(lambda u=url, h=headers: streamablehttp_client(u, headers=h))
            client.start()
            clients.append(client)
            logger.info('[mcp_loader] Connected to MCP server: %s', name)
        except Exception as e:
            logger.error('[mcp_loader] Failed to connect to MCP server "%s": %s', name, e)
            invalidate_token(name, actor_id)
    return clients, clients
def close_mcp_clients(clients: list):
    """Stop each MCP client in turn.

    An error raised by any single ``stop()`` call is logged and swallowed, so
    one misbehaving client cannot keep the remaining clients from being stopped.
    """
    for mcp_client in clients:
        try:
            mcp_client.stop()
        except Exception as stop_error:
            logger.error(f'[mcp_loader] Error closing MCP client: {stop_error}')