feat: add user-configurable MCP connections
- manage_mcp_connection tool: add/remove/enable/disable/list MCP servers - mcp_loader: dynamic connection with OAuth/bearer/none auth, token caching - Secrets stored in SSM, never in DynamoDB - MCP clients loaded per-session and cleaned up in finally block
This commit is contained in:
@@ -19,7 +19,10 @@ import tools.scheduler as _scheduler_module
|
||||
from tools.home_assistant import home_assistant, set_ha_config
|
||||
from tools.google_workspace import list_calendars, get_calendar_events, list_gmail_messages, get_gmail_message
|
||||
from tools.send_file import send_file as _send_file_impl
|
||||
from tools.mcp_tools import manage_mcp_connection
|
||||
import tools.mcp_tools as _mcp_tools_module
|
||||
import tools.google_workspace as _gws
|
||||
import mcp_loader
|
||||
import httpx
|
||||
import botocore.auth
|
||||
import botocore.awsrequest
|
||||
@@ -313,6 +316,7 @@ async def main(payload: dict, context):
|
||||
_current_chat_id = chat_id
|
||||
_scheduler_module._current_actor_id = actor_id
|
||||
_scheduler_module._current_chat_id = chat_id
|
||||
_mcp_tools_module._current_actor_id = actor_id
|
||||
|
||||
# Run compaction if flagged from previous invocation (trims old events before load)
|
||||
memory_manager.check_and_compact(actor_id, session_id)
|
||||
@@ -381,15 +385,20 @@ async def main(payload: dict, context):
|
||||
|
||||
base_tools = [web_search, web_fetch, read_workspace_file, write_workspace_file,
|
||||
home_assistant, connect_google_account, list_google_accounts, remove_google_account,
|
||||
manage_service, schedule_reminder, list_reminders, cancel_reminder,
|
||||
manage_service, manage_mcp_connection, schedule_reminder, list_reminders, cancel_reminder,
|
||||
list_calendars, get_calendar_events, list_gmail_messages, get_gmail_message,
|
||||
run_code, send_file]
|
||||
|
||||
# Load user's dynamic MCP connections
|
||||
mcp_connections = services.get('mcp_connections', [])
|
||||
mcp_clients, _mcp_to_close = mcp_loader.load_mcp_tools(mcp_connections, actor_id)
|
||||
all_tools = base_tools + mcp_clients
|
||||
|
||||
agent = Agent(
|
||||
model=model,
|
||||
system_prompt=system_prompt,
|
||||
session_manager=session_manager,
|
||||
tools=base_tools,
|
||||
tools=all_tools,
|
||||
)
|
||||
|
||||
final_message = None
|
||||
@@ -407,6 +416,7 @@ async def main(payload: dict, context):
|
||||
finally:
|
||||
_typing_active = False
|
||||
session_manager.close()
|
||||
mcp_loader.close_mcp_clients(_mcp_to_close)
|
||||
# Check if session exceeds window — flag for compaction on next invocation
|
||||
memory_manager.check_window_and_flag(actor_id, session_id)
|
||||
|
||||
|
||||
94
agentclaw/app/agent_claw_main/mcp_loader.py
Normal file
94
agentclaw/app/agent_claw_main/mcp_loader.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Dynamic MCP tool loader — connects to user-configured MCP servers and returns their tools."""
|
||||
import time
|
||||
import logging
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
import json
|
||||
import boto3
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from strands.tools.mcp.mcp_client import MCPClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Token cache: {f"{actor_id}:{conn_name}": {"token": str, "expires_at": float}}
|
||||
_token_cache: dict = {}
|
||||
|
||||
|
||||
def _get_ssm_value(param_name: str) -> str:
|
||||
ssm = boto3.client('ssm', region_name='us-east-1')
|
||||
return ssm.get_parameter(Name=param_name, WithDecryption=True)['Parameter']['Value']
|
||||
|
||||
|
||||
def _get_oauth_token(conn: dict, actor_id: str) -> str:
|
||||
"""Fetch OAuth token via client_credentials grant, with caching."""
|
||||
cache_key = f"{actor_id}:{conn['name']}"
|
||||
cached = _token_cache.get(cache_key)
|
||||
if cached and cached['expires_at'] > time.time():
|
||||
return cached['token']
|
||||
|
||||
client_secret = _get_ssm_value(conn['client_secret_ssm'])
|
||||
data = urllib.parse.urlencode({
|
||||
'grant_type': 'client_credentials',
|
||||
'client_id': conn['client_id'],
|
||||
'client_secret': client_secret,
|
||||
'scope': conn.get('scope', ''),
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(conn['cognito_token_url'], data=data,
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'})
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
body = json.loads(resp.read())
|
||||
|
||||
token = body['access_token']
|
||||
expires_in = body.get('expires_in', 3600)
|
||||
_token_cache[cache_key] = {'token': token, 'expires_at': time.time() + expires_in - 30}
|
||||
return token
|
||||
|
||||
|
||||
def _resolve_auth_headers(conn: dict, actor_id: str) -> dict:
|
||||
"""Resolve auth headers for a connection."""
|
||||
auth_type = conn.get('auth_type', 'none')
|
||||
if auth_type == 'oauth_client_credentials':
|
||||
token = _get_oauth_token(conn, actor_id)
|
||||
return {'Authorization': f'Bearer {token}'}
|
||||
elif auth_type == 'bearer':
|
||||
token = _get_ssm_value(conn['token_ssm'])
|
||||
return {'Authorization': f'Bearer {token}'}
|
||||
return {}
|
||||
|
||||
|
||||
def invalidate_token(conn_name: str, actor_id: str):
|
||||
"""Invalidate cached token for a connection (call on auth failure)."""
|
||||
_token_cache.pop(f"{actor_id}:{conn_name}", None)
|
||||
|
||||
|
||||
def load_mcp_tools(mcp_connections: list, actor_id: str) -> tuple[list, list]:
|
||||
"""Connect to each enabled MCP server and return (tools_list, clients_to_close).
|
||||
|
||||
Returns:
|
||||
Tuple of (list of MCPClient instances to pass to Agent, list of same clients to close later)
|
||||
"""
|
||||
clients = []
|
||||
for conn in mcp_connections:
|
||||
if not conn.get('enabled', True):
|
||||
continue
|
||||
name = conn.get('name', 'unknown')
|
||||
try:
|
||||
headers = _resolve_auth_headers(conn, actor_id)
|
||||
url = conn['url']
|
||||
client = MCPClient(lambda u=url, h=headers: streamablehttp_client(u, headers=h))
|
||||
client.start()
|
||||
clients.append(client)
|
||||
logger.info(f'[mcp_loader] Connected to MCP server: {name}')
|
||||
except Exception as e:
|
||||
logger.error(f'[mcp_loader] Failed to connect to MCP server "{name}": {e}')
|
||||
return clients, clients
|
||||
|
||||
|
||||
def close_mcp_clients(clients: list):
|
||||
"""Close all MCP clients."""
|
||||
for client in clients:
|
||||
try:
|
||||
client.stop()
|
||||
except Exception as e:
|
||||
logger.error(f'[mcp_loader] Error closing MCP client: {e}')
|
||||
148
agentclaw/app/agent_claw_main/tools/mcp_tools.py
Normal file
148
agentclaw/app/agent_claw_main/tools/mcp_tools.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""MCP connection management tool — add/remove/enable/disable user MCP servers."""
|
||||
import boto3
|
||||
from strands import tool
|
||||
|
||||
USERS_TABLE_NAME = 'agent-claw-users'
|
||||
|
||||
# Set per-invocation by main.py
|
||||
_current_actor_id: str = ''
|
||||
|
||||
|
||||
@tool
|
||||
def manage_mcp_connection(action: str, name: str = '', url: str = '',
|
||||
auth_type: str = 'none', cognito_token_url: str = '',
|
||||
client_id: str = '', client_secret: str = '',
|
||||
scope: str = '', token: str = '') -> str:
|
||||
"""Add, remove, enable, disable, or list MCP server connections.
|
||||
|
||||
Actions: add, remove, enable, disable, list
|
||||
|
||||
For add with auth_type=oauth_client_credentials, provide:
|
||||
- cognito_token_url: Cognito token endpoint
|
||||
- client_id: OAuth client ID
|
||||
- client_secret: Secret value (stored securely in SSM, not in database)
|
||||
- scope: Space-separated OAuth scopes
|
||||
|
||||
For add with auth_type=bearer, provide:
|
||||
- token: Bearer token value (stored securely in SSM, not in database)
|
||||
|
||||
For add with auth_type=none, only name and url are required.
|
||||
|
||||
Args:
|
||||
action: One of "add", "remove", "enable", "disable", "list".
|
||||
name: Connection name (required for add/remove/enable/disable).
|
||||
url: MCP server URL (required for add).
|
||||
auth_type: One of "none", "bearer", "oauth_client_credentials".
|
||||
cognito_token_url: Token endpoint for oauth_client_credentials.
|
||||
client_id: OAuth client ID for oauth_client_credentials.
|
||||
client_secret: OAuth client secret (will be stored in SSM).
|
||||
scope: OAuth scopes for oauth_client_credentials.
|
||||
token: Bearer token (will be stored in SSM).
|
||||
"""
|
||||
actor_id = _current_actor_id
|
||||
if not actor_id:
|
||||
return 'Cannot determine actor_id.'
|
||||
|
||||
ddb = boto3.resource('dynamodb', region_name='us-east-1')
|
||||
table = ddb.Table(USERS_TABLE_NAME)
|
||||
|
||||
if action == 'list':
|
||||
resp = table.get_item(Key={'actor_id': actor_id})
|
||||
connections = resp.get('Item', {}).get('services', {}).get('mcp_connections', [])
|
||||
if not connections:
|
||||
return 'No MCP connections configured.'
|
||||
lines = []
|
||||
for c in connections:
|
||||
status = '✓' if c.get('enabled', True) else '✗'
|
||||
lines.append(f" {status} {c['name']}: {c['url']} (auth: {c.get('auth_type', 'none')})")
|
||||
return 'MCP connections:\n' + '\n'.join(lines)
|
||||
|
||||
if not name:
|
||||
return 'name is required for this action.'
|
||||
|
||||
if action == 'add':
|
||||
if not url:
|
||||
return 'url is required for add.'
|
||||
if auth_type not in ('none', 'bearer', 'oauth_client_credentials'):
|
||||
return f'Invalid auth_type: {auth_type}. Use none, bearer, or oauth_client_credentials.'
|
||||
|
||||
ssm = boto3.client('ssm', region_name='us-east-1')
|
||||
safe_actor = actor_id.replace(':', '-')
|
||||
ssm_prefix = f'/agent-claw/mcp/{safe_actor}/{name}'
|
||||
|
||||
conn = {'name': name, 'url': url, 'auth_type': auth_type, 'enabled': True}
|
||||
|
||||
if auth_type == 'oauth_client_credentials':
|
||||
if not cognito_token_url or not client_id or not client_secret:
|
||||
return 'oauth_client_credentials requires cognito_token_url, client_id, and client_secret.'
|
||||
ssm.put_parameter(Name=f'{ssm_prefix}/client-secret', Value=client_secret,
|
||||
Type='SecureString', Overwrite=True)
|
||||
conn['cognito_token_url'] = cognito_token_url
|
||||
conn['client_id'] = client_id
|
||||
conn['client_secret_ssm'] = f'{ssm_prefix}/client-secret'
|
||||
conn['scope'] = scope
|
||||
elif auth_type == 'bearer':
|
||||
if not token:
|
||||
return 'bearer auth requires token.'
|
||||
ssm.put_parameter(Name=f'{ssm_prefix}/token', Value=token,
|
||||
Type='SecureString', Overwrite=True)
|
||||
conn['token_ssm'] = f'{ssm_prefix}/token'
|
||||
|
||||
# Upsert into mcp_connections list
|
||||
resp = table.get_item(Key={'actor_id': actor_id})
|
||||
services = resp.get('Item', {}).get('services', {})
|
||||
connections = services.get('mcp_connections', [])
|
||||
connections = [c for c in connections if c['name'] != name]
|
||||
connections.append(conn)
|
||||
|
||||
table.update_item(
|
||||
Key={'actor_id': actor_id},
|
||||
UpdateExpression='SET services = if_not_exists(services, :empty), services.mcp_connections = :conns',
|
||||
ExpressionAttributeValues={':conns': connections, ':empty': {}},
|
||||
)
|
||||
return f'MCP connection "{name}" added ({auth_type} auth). It will be available on your next message.'
|
||||
|
||||
elif action == 'remove':
|
||||
resp = table.get_item(Key={'actor_id': actor_id})
|
||||
connections = resp.get('Item', {}).get('services', {}).get('mcp_connections', [])
|
||||
found = [c for c in connections if c['name'] == name]
|
||||
if not found:
|
||||
return f'No connection named "{name}" found.'
|
||||
|
||||
# Clean up SSM secrets
|
||||
ssm = boto3.client('ssm', region_name='us-east-1')
|
||||
safe_actor = actor_id.replace(':', '-')
|
||||
for key in ('client-secret', 'token'):
|
||||
try:
|
||||
ssm.delete_parameter(Name=f'/agent-claw/mcp/{safe_actor}/{name}/{key}')
|
||||
except ssm.exceptions.ParameterNotFound:
|
||||
pass
|
||||
|
||||
connections = [c for c in connections if c['name'] != name]
|
||||
table.update_item(
|
||||
Key={'actor_id': actor_id},
|
||||
UpdateExpression='SET services.mcp_connections = :conns',
|
||||
ExpressionAttributeValues={':conns': connections},
|
||||
)
|
||||
return f'MCP connection "{name}" removed.'
|
||||
|
||||
elif action in ('enable', 'disable'):
|
||||
resp = table.get_item(Key={'actor_id': actor_id})
|
||||
connections = resp.get('Item', {}).get('services', {}).get('mcp_connections', [])
|
||||
updated = False
|
||||
for c in connections:
|
||||
if c['name'] == name:
|
||||
c['enabled'] = (action == 'enable')
|
||||
updated = True
|
||||
break
|
||||
if not updated:
|
||||
return f'No connection named "{name}" found.'
|
||||
table.update_item(
|
||||
Key={'actor_id': actor_id},
|
||||
UpdateExpression='SET services.mcp_connections = :conns',
|
||||
ExpressionAttributeValues={':conns': connections},
|
||||
)
|
||||
return f'MCP connection "{name}" {action}d.'
|
||||
|
||||
else:
|
||||
return f'Unknown action: {action}. Use add, remove, enable, disable, or list.'
|
||||
Reference in New Issue
Block a user