feat: add user-configurable MCP connections
- manage_mcp_connection tool: add/remove/enable/disable/list MCP servers - mcp_loader: dynamic connection with OAuth/bearer/none auth, token caching - Secrets stored in SSM, never in DynamoDB - MCP clients loaded per-session and cleaned up in finally block
This commit is contained in:
94
agentclaw/app/agent_claw_main/mcp_loader.py
Normal file
94
agentclaw/app/agent_claw_main/mcp_loader.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Dynamic MCP tool loader — connects to user-configured MCP servers and returns their tools."""
|
||||
import time
|
||||
import logging
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
import json
|
||||
import boto3
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from strands.tools.mcp.mcp_client import MCPClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Token cache: {f"{actor_id}:{conn_name}": {"token": str, "expires_at": float}}
|
||||
_token_cache: dict = {}
|
||||
|
||||
|
||||
def _get_ssm_value(param_name: str) -> str:
|
||||
ssm = boto3.client('ssm', region_name='us-east-1')
|
||||
return ssm.get_parameter(Name=param_name, WithDecryption=True)['Parameter']['Value']
|
||||
|
||||
|
||||
def _get_oauth_token(conn: dict, actor_id: str) -> str:
|
||||
"""Fetch OAuth token via client_credentials grant, with caching."""
|
||||
cache_key = f"{actor_id}:{conn['name']}"
|
||||
cached = _token_cache.get(cache_key)
|
||||
if cached and cached['expires_at'] > time.time():
|
||||
return cached['token']
|
||||
|
||||
client_secret = _get_ssm_value(conn['client_secret_ssm'])
|
||||
data = urllib.parse.urlencode({
|
||||
'grant_type': 'client_credentials',
|
||||
'client_id': conn['client_id'],
|
||||
'client_secret': client_secret,
|
||||
'scope': conn.get('scope', ''),
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(conn['cognito_token_url'], data=data,
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'})
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
body = json.loads(resp.read())
|
||||
|
||||
token = body['access_token']
|
||||
expires_in = body.get('expires_in', 3600)
|
||||
_token_cache[cache_key] = {'token': token, 'expires_at': time.time() + expires_in - 30}
|
||||
return token
|
||||
|
||||
|
||||
def _resolve_auth_headers(conn: dict, actor_id: str) -> dict:
|
||||
"""Resolve auth headers for a connection."""
|
||||
auth_type = conn.get('auth_type', 'none')
|
||||
if auth_type == 'oauth_client_credentials':
|
||||
token = _get_oauth_token(conn, actor_id)
|
||||
return {'Authorization': f'Bearer {token}'}
|
||||
elif auth_type == 'bearer':
|
||||
token = _get_ssm_value(conn['token_ssm'])
|
||||
return {'Authorization': f'Bearer {token}'}
|
||||
return {}
|
||||
|
||||
|
||||
def invalidate_token(conn_name: str, actor_id: str):
|
||||
"""Invalidate cached token for a connection (call on auth failure)."""
|
||||
_token_cache.pop(f"{actor_id}:{conn_name}", None)
|
||||
|
||||
|
||||
def load_mcp_tools(mcp_connections: list, actor_id: str) -> tuple[list, list]:
|
||||
"""Connect to each enabled MCP server and return (tools_list, clients_to_close).
|
||||
|
||||
Returns:
|
||||
Tuple of (list of MCPClient instances to pass to Agent, list of same clients to close later)
|
||||
"""
|
||||
clients = []
|
||||
for conn in mcp_connections:
|
||||
if not conn.get('enabled', True):
|
||||
continue
|
||||
name = conn.get('name', 'unknown')
|
||||
try:
|
||||
headers = _resolve_auth_headers(conn, actor_id)
|
||||
url = conn['url']
|
||||
client = MCPClient(lambda u=url, h=headers: streamablehttp_client(u, headers=h))
|
||||
client.start()
|
||||
clients.append(client)
|
||||
logger.info(f'[mcp_loader] Connected to MCP server: {name}')
|
||||
except Exception as e:
|
||||
logger.error(f'[mcp_loader] Failed to connect to MCP server "{name}": {e}')
|
||||
return clients, clients
|
||||
|
||||
|
||||
def close_mcp_clients(clients: list):
|
||||
"""Close all MCP clients."""
|
||||
for client in clients:
|
||||
try:
|
||||
client.stop()
|
||||
except Exception as e:
|
||||
logger.error(f'[mcp_loader] Error closing MCP client: {e}')
|
||||
Reference in New Issue
Block a user