- Add oauth2_m2m auth type to mcp_loader.py (client_secret in record, not SSM) - Remove _get_factcloud_token(), FACTCLOUD_* config, factcloud_clients from main.py - Seed Daniel's factcloud connection into enrolled_services.mcp_connections - factcloud now loaded dynamically via mcp_loader at session start
123 lines
4.5 KiB
Python
123 lines
4.5 KiB
Python
"""Dynamic MCP tool loader — connects to user-configured MCP servers and returns their tools."""
|
|
import time
|
|
import logging
|
|
import urllib.request
|
|
import urllib.parse
|
|
import json
|
|
import boto3
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
from strands.tools.mcp.mcp_client import MCPClient
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# In-process token cache, keyed by "<actor_id>:<conn_name>". Each value is {"token": str, "expires_at": float}, where expires_at is a time.time() epoch timestamp.
|
|
_token_cache: dict = {}
|
|
|
|
|
|
def _get_ssm_value(param_name: str) -> str:
    """Return the value of an SSM parameter, requested with decryption.

    A new boto3 SSM client pinned to ``us-east-1`` is created on each call.

    Args:
        param_name: Name of the SSM parameter to read.

    Returns:
        The ``Value`` field of the returned ``Parameter`` record.
    """
    ssm_client = boto3.client('ssm', region_name='us-east-1')
    response = ssm_client.get_parameter(Name=param_name, WithDecryption=True)
    parameter = response['Parameter']
    return parameter['Value']
|
|
|
|
|
|
def _get_oauth_token(conn: dict, actor_id: str) -> str:
    """Return a bearer token obtained via the client_credentials grant.

    Tokens are cached in ``_token_cache`` under ``"<actor_id>:<conn name>"``
    and reused until their recorded expiry has passed. On a cache miss the
    client secret is read from the SSM parameter named by
    ``conn['client_secret_ssm']`` and a form-encoded request is sent to
    ``conn['cognito_token_url']`` with a 10 second timeout.

    Args:
        conn: Connection record. Reads ``name``, ``client_secret_ssm``,
            ``client_id``, ``cognito_token_url`` and the optional ``scope``.
        actor_id: Identifier used as the first half of the cache key.

    Returns:
        The ``access_token`` field of the token endpoint's JSON response.
    """
    key = f"{actor_id}:{conn['name']}"
    entry = _token_cache.get(key)
    if entry and entry['expires_at'] > time.time():
        return entry['token']

    secret = _get_ssm_value(conn['client_secret_ssm'])
    form_fields = {
        'grant_type': 'client_credentials',
        'client_id': conn['client_id'],
        'client_secret': secret,
        'scope': conn.get('scope', ''),
    }
    payload = urllib.parse.urlencode(form_fields).encode()

    token_request = urllib.request.Request(
        conn['cognito_token_url'],
        data=payload,
        headers={'Content-Type': 'application/x-www-form-urlencoded'},
    )
    with urllib.request.urlopen(token_request, timeout=10) as response:
        raw_body = response.read()
    parsed = json.loads(raw_body)

    access_token = parsed['access_token']
    # Fall back to one hour when the response does not report a lifetime.
    lifetime = parsed.get('expires_in', 3600)
    # Record the expiry 30 seconds early so a token is not reused right at
    # the edge of its lifetime.
    _token_cache[key] = {
        'token': access_token,
        'expires_at': time.time() + lifetime - 30,
    }
    return access_token
|
|
|
|
|
|
def _get_m2m_token(conn: dict, actor_id: str) -> str:
    """Fetch OAuth token for oauth2_m2m (secret stored directly in record).

    Performs a client_credentials grant against ``conn['token_url']`` using
    ``conn['client_id']`` and ``conn['client_secret']`` taken straight from
    the connection record. Tokens are cached in ``_token_cache`` under
    ``"<actor_id>:<conn name>"``.

    Args:
        conn: Connection record. Reads ``name``, ``client_id``,
            ``client_secret``, ``token_url`` and an optional ``scopes``
            (falling back to ``scope``). The scope value may be a string or
            a list/tuple of scope names.
        actor_id: Identifier used as the first half of the cache key.

    Returns:
        The ``access_token`` field of the token endpoint's JSON response.

    Raises:
        KeyError: If a required key is missing from ``conn`` or the response
            has no ``access_token``.
        Exception: Errors raised by ``urllib.request.urlopen`` or
            ``json.loads`` propagate to the caller unchanged.
    """
    cache_key = f"{actor_id}:{conn['name']}"
    cached = _token_cache.get(cache_key)
    # Only reuse a cached token that still has more than 60 seconds left.
    if cached and cached['expires_at'] > time.time() + 60:
        return cached['token']

    scope = conn.get('scopes', conn.get('scope', ''))
    # urlencode() would serialise a list as its Python repr ("['a', 'b']");
    # OAuth2 expects a single space-delimited string instead.
    if isinstance(scope, (list, tuple)):
        scope = ' '.join(str(item) for item in scope)

    data = urllib.parse.urlencode({
        'grant_type': 'client_credentials',
        'client_id': conn['client_id'],
        'client_secret': conn['client_secret'],
        'scope': scope,
    }).encode()

    req = urllib.request.Request(
        conn['token_url'],
        data=data,
        headers={'Content-Type': 'application/x-www-form-urlencoded'},
    )
    with urllib.request.urlopen(req, timeout=10) as resp:
        body = json.loads(resp.read())

    token = body['access_token']
    # Fall back to one hour when the response does not report a lifetime.
    expires_in = body.get('expires_in', 3600)
    _token_cache[cache_key] = {'token': token, 'expires_at': time.time() + expires_in}
    return token
|
|
|
|
|
|
def _resolve_auth_headers(conn: dict, actor_id: str) -> dict:
    """Build the HTTP auth headers required by a connection.

    The connection's ``auth_type`` (default ``'none'``) selects where the
    bearer token comes from:

    * ``'oauth_client_credentials'`` -- ``_get_oauth_token``
    * ``'oauth2_m2m'`` -- ``_get_m2m_token``
    * ``'bearer'`` -- the SSM parameter named by ``conn['token_ssm']``

    Args:
        conn: Connection record.
        actor_id: Passed through to the token helpers for cache keying.

    Returns:
        ``{'Authorization': 'Bearer <token>'}`` for the auth types above,
        or an empty dict for any other ``auth_type``.
    """
    kind = conn.get('auth_type', 'none')

    if kind == 'oauth_client_credentials':
        bearer = _get_oauth_token(conn, actor_id)
    elif kind == 'oauth2_m2m':
        bearer = _get_m2m_token(conn, actor_id)
    elif kind == 'bearer':
        bearer = _get_ssm_value(conn['token_ssm'])
    else:
        # Unknown or absent auth type: connect without credentials.
        return {}

    return {'Authorization': f'Bearer {bearer}'}
|
|
|
|
|
|
def invalidate_token(conn_name: str, actor_id: str):
    """Drop the cached token for a connection (call on auth failure).

    Removing an entry that is not cached is a no-op.

    Args:
        conn_name: The connection's ``name`` value.
        actor_id: Identifier the token was cached under.
    """
    stale_key = f"{actor_id}:{conn_name}"
    _token_cache.pop(stale_key, None)
|
|
|
|
|
|
def load_mcp_tools(mcp_connections: list, actor_id: str) -> tuple[list, list]:
    """Start an MCP client for every enabled connection.

    Connections whose ``enabled`` flag is falsy are skipped (a missing flag
    counts as enabled). A connection that fails during header resolution,
    client construction or start-up is logged and skipped, so one bad
    server does not stop the others from loading.

    Args:
        mcp_connections: Connection records; each needs a ``url`` and may
            carry ``name``, ``enabled`` and auth settings.
        actor_id: Passed to ``_resolve_auth_headers`` for token lookup.

    Returns:
        Tuple of (clients to pass to Agent, clients to close later). Both
        elements are the same list object.
    """
    started = []
    enabled_connections = (
        entry for entry in mcp_connections if entry.get('enabled', True)
    )

    for entry in enabled_connections:
        label = entry.get('name', 'unknown')
        try:
            auth_headers = _resolve_auth_headers(entry, actor_id)
            endpoint = entry['url']

            # Defaults bind this iteration's values, so the factory does not
            # see values from a later iteration of the loop.
            def open_transport(target=endpoint, extra_headers=auth_headers):
                return streamablehttp_client(target, headers=extra_headers)

            mcp_client = MCPClient(open_transport)
            mcp_client.start()
            started.append(mcp_client)
            logger.info(f'[mcp_loader] Connected to MCP server: {label}')
        except Exception as exc:
            logger.error(f'[mcp_loader] Failed to connect to MCP server "{label}": {exc}')

    return started, started
|
|
|
|
|
|
def close_mcp_clients(clients: list):
    """Stop every MCP client in ``clients``.

    A failure while stopping one client is logged and does not prevent the
    remaining clients from being stopped.

    Args:
        clients: Objects exposing a ``stop()`` method.
    """
    def stop_quietly(mcp_client) -> None:
        # Log and continue so one failing client cannot block the rest.
        try:
            mcp_client.stop()
        except Exception as exc:
            logger.error(f'[mcp_loader] Error closing MCP client: {exc}')

    for mcp_client in clients:
        stop_quietly(mcp_client)
|