Multi-account Google support with user labels

This commit is contained in:
daniel
2026-05-09 11:21:37 -05:00
parent 01b258579b
commit 38d828ef74
4 changed files with 239 additions and 141 deletions

View File

@@ -4,21 +4,26 @@ Mirrors workspace-mcp (gcalendar/gmail) logic using google-api-python-client dir
since workspace-mcp tool functions require FastMCP request context and cannot be called
outside an MCP server.
Credential secret: agent-claw/google-credentials/{actor_id.replace(':', '-')}
Contains: token, refresh_token, token_uri, client_id, client_secret, scopes
Credential secrets: agent-claw/google-credentials/{safe_actor_id}/{label}
Backward compat: agent-claw/google-credentials/{safe_actor_id} (treated as "primary")
"""
import json
import time
import traceback
import boto3
import httplib2
from strands import tool
from google.oauth2.credentials import Credentials
from google.auth.transport.requests import Request
from googleapiclient.discovery import build
from datetime import datetime, timezone, timedelta
_HTTP_TIMEOUT = 15
_sm = None
# Cache: actor_id -> (timestamp, {label: Credentials})
_creds_cache: dict[str, tuple[float, dict[str, Credentials]]] = {}
# Set per-invocation by main.py
_current_actor_id: str = ''
_current_google_accounts: dict = {} # {label: email} from DynamoDB
def _secrets():
@@ -28,24 +33,24 @@ def _secrets():
return _sm
def _get_creds(actor_id: str) -> Credentials:
from datetime import datetime
secret_name = 'agent-claw/google-credentials/' + actor_id.replace(':', '-')
print(f'[google] fetching creds actor={actor_id}')
data = json.loads(_secrets().get_secret_value(SecretId=secret_name)['SecretString'])
def _actor_id():
return _current_actor_id
def _load_creds_from_secret(secret_name: str) -> Credentials:
"""Load, optionally refresh, and return Credentials from a named secret."""
sm = _secrets()
data = json.loads(sm.get_secret_value(SecretId=secret_name)['SecretString'])
expiry_str = data.get('expiry')
expiry = None
if expiry_str:
exp_aware = datetime.fromisoformat(expiry_str.replace('Z', '+00:00'))
expiry = exp_aware.replace(tzinfo=None) # google-auth uses naive UTC datetimes
else:
expiry = None
expiry = exp_aware.replace(tzinfo=None)
stored_scopes = data.get('scopes', [])
api_scopes = [s for s in stored_scopes if s.startswith('https://')] if stored_scopes else None
# Fix stored scopes if they contain OIDC scopes
if stored_scopes and any(s in stored_scopes for s in ['openid', 'email', 'profile']):
data['scopes'] = api_scopes
_secrets().put_secret_value(SecretId=secret_name, SecretString=json.dumps(data))
print('[google] fixed stored scopes: removed OIDC scopes')
sm.put_secret_value(SecretId=secret_name, SecretString=json.dumps(data))
creds = Credentials(
token=data.get('token'),
refresh_token=data.get('refresh_token'),
@@ -55,51 +60,96 @@ def _get_creds(actor_id: str) -> Credentials:
scopes=api_scopes,
expiry=expiry,
)
print(f'[google] creds loaded, expired={creds.expired}')
if (creds.expired or not creds.valid) and creds.refresh_token:
print('[google] refreshing token')
creds.refresh(Request())
data['token'] = creds.token
if creds.expiry:
data['expiry'] = creds.expiry.isoformat()
_secrets().put_secret_value(SecretId=secret_name, SecretString=json.dumps(data))
print('[google] token refreshed and saved')
sm.put_secret_value(SecretId=secret_name, SecretString=json.dumps(data))
return creds
def _actor_id():
# Read from module-level var set by main.py per invocation
# DO NOT use 'import main as _main' — it re-runs main.py including app.run() which hangs
return _current_actor_id
def _load_all_creds(actor_id: str) -> dict[str, Credentials]:
"""Load all labeled credentials for actor_id, with 5-min TTL cache."""
now = time.time()
if actor_id in _creds_cache:
ts, cached = _creds_cache[actor_id]
if now - ts < 300:
return cached
safe = actor_id.replace(':', '-').replace('/', '-')
prefix = f'agent-claw/google-credentials/{safe}/'
sm = _secrets()
result: dict[str, Credentials] = {}
# Set per-invocation by main.py before any tool call
_current_actor_id: str = ''
try:
paginator = sm.get_paginator('list_secrets')
for page in paginator.paginate(Filters=[{'Key': 'name', 'Values': [prefix]}]):
for secret in page.get('SecretList', []):
name = secret['Name']
label = name[len(prefix):]
if not label or '/' in label:
continue
try:
result[label] = _load_creds_from_secret(name)
print(f'[google] loaded creds actor={actor_id} label={label}')
except Exception as e:
print(f'[google] failed to load label={label}: {e}')
except Exception as e:
print(f'[google] list_secrets failed: {e}')
# Backward compat: flat secret path
if not result:
flat = f'agent-claw/google-credentials/{safe}'
try:
result['primary'] = _load_creds_from_secret(flat)
print(f'[google] loaded creds from flat path actor={actor_id}')
except Exception:
pass
_creds_cache[actor_id] = (now, result)
return result
def _svc(api: str, version: str, creds: Credentials):
# Standard google-auth pattern: pass credentials= directly to build().
# google.oauth2.credentials.Credentials is natively supported by googleapiclient.
# (creds.authorize() is oauth2client only — not available here)
return build(api, version, credentials=creds, cache_discovery=False)
def _get_creds_for_label(all_creds: dict[str, Credentials], label: str | None):
"""Return {label: creds} filtered by label, or all if label is None."""
if label:
if label not in all_creds:
return {}
return {label: all_creds[label]}
return all_creds
@tool
def list_calendars() -> str:
"""List all Google Calendars for the current user."""
def list_calendars(account_label: str = None) -> str:
"""List all Google Calendars for the current user.
Args:
account_label: Optional account label (e.g. 'work', 'personal'). Lists all accounts if omitted.
"""
try:
creds = _get_creds(_actor_id())
result = _svc('calendar', 'v3', creds).calendarList().list().execute()
items = result.get('items', [])
if not items:
return 'No calendars found.'
return '\n'.join(
f'- "{c.get("summary", "")}"{" (Primary)" if c.get("primary") else ""} (ID: {c["id"]})'
for c in items
)
all_creds = _load_all_creds(_actor_id())
if not all_creds:
return 'No Google accounts connected. Use connect_google_account to add one.'
creds_map = _get_creds_for_label(all_creds, account_label)
if not creds_map:
return f'No account with label "{account_label}" found.'
multi = len(creds_map) > 1
parts = []
for label, creds in creds_map.items():
items = _svc('calendar', 'v3', creds).calendarList().list().execute().get('items', [])
lines = [
f'{"[" + label + "] " if multi else ""}- "{c.get("summary", "")}"{" (Primary)" if c.get("primary") else ""} (ID: {c["id"]})'
for c in items
]
parts.append('\n'.join(lines) if lines else f'{"[" + label + "] " if multi else ""}No calendars found.')
return '\n'.join(parts)
except Exception as e:
tb = traceback.format_exc()
print(f'[google] list_calendars error: {e}\n{tb}')
print(f'[google] list_calendars error: {e}\n{traceback.format_exc()}')
return f'Error listing calendars: {e}'
@@ -111,6 +161,7 @@ def get_calendar_events(
time_max: str = '',
max_results: int = 25,
query: str = '',
account_label: str = None,
) -> str:
"""Get upcoming Google Calendar events.
@@ -121,10 +172,16 @@ def get_calendar_events(
time_max: End of time range in RFC3339 format (optional)
max_results: Maximum events to return (default: 25)
query: Keyword search within event fields (optional)
account_label: Optional account label (e.g. 'work', 'personal'). Queries all accounts if omitted.
"""
try:
creds = _get_creds(_actor_id())
svc = _svc('calendar', 'v3', creds)
all_creds = _load_all_creds(_actor_id())
if not all_creds:
return 'No Google accounts connected.'
creds_map = _get_creds_for_label(all_creds, account_label)
if not creds_map:
return f'No account with label "{account_label}" found.'
multi = len(creds_map) > 1
now = datetime.now(timezone.utc)
params = {
'calendarId': calendar_id,
@@ -136,92 +193,118 @@ def get_calendar_events(
}
if query:
params['q'] = query
events = svc.events().list(**params).execute().get('items', [])
if not events:
return f'No events found in calendar "{calendar_id}".'
lines = []
for e in events:
start = e['start'].get('dateTime', e['start'].get('date', ''))
end = e['end'].get('dateTime', e['end'].get('date', ''))
eid = e.get('id', '')
lines.append(f'- "{e.get("summary", "No Title")}" (Starts: {start}, Ends: {end}) ID: {eid}')
return f'Retrieved {len(events)} events from "{calendar_id}":\n' + '\n'.join(lines)
parts = []
for label, creds in creds_map.items():
events = _svc('calendar', 'v3', creds).events().list(**params).execute().get('items', [])
if not events:
parts.append(f'{"[" + label + "] " if multi else ""}No events found in calendar "{calendar_id}".')
continue
lines = []
for e in events:
start = e['start'].get('dateTime', e['start'].get('date', ''))
end = e['end'].get('dateTime', e['end'].get('date', ''))
prefix = f'[{label}] ' if multi else ''
lines.append(f'{prefix}- "{e.get("summary", "No Title")}" (Starts: {start}, Ends: {end}) ID: {e.get("id", "")}')
parts.append(f'Retrieved {len(events)} events{" [" + label + "]" if multi else ""} from "{calendar_id}":\n' + '\n'.join(lines))
return '\n\n'.join(parts)
except Exception as e:
tb = traceback.format_exc()
print(f'[google] get_calendar_events error: {e}\n{tb}')
print(f'[google] get_calendar_events error: {e}\n{traceback.format_exc()}')
return f'Error fetching calendar events: {e}'
@tool
def list_gmail_messages(max_results: int = 10, query: str = 'in:inbox') -> str:
def list_gmail_messages(max_results: int = 10, query: str = 'in:inbox', account_label: str = None) -> str:
"""List Gmail messages.
Args:
max_results: Maximum number of messages to return (default: 10)
query: Gmail search query (default: 'in:inbox')
account_label: Optional account label (e.g. 'work', 'personal'). Lists all accounts if omitted.
"""
try:
creds = _get_creds(_actor_id())
svc = _svc('gmail', 'v1', creds)
result = svc.users().messages().list(userId='me', q=query, maxResults=max_results).execute()
messages = result.get('messages', [])
if not messages:
return 'No messages found.'
lines = []
for m in messages:
msg = svc.users().messages().get(
userId='me', id=m['id'], format='metadata',
metadataHeaders=['Subject', 'From', 'Date']
).execute()
h = {hdr['name']: hdr['value'] for hdr in msg.get('payload', {}).get('headers', [])}
lines.append(f"id={m['id']} | {h.get('Date', '')} | From: {h.get('From', '')} | {h.get('Subject', '(no subject)')}")
next_token = result.get('nextPageToken')
out = '\n'.join(lines)
if next_token:
out += f'\n(more results available)'
return out
all_creds = _load_all_creds(_actor_id())
if not all_creds:
return 'No Google accounts connected.'
creds_map = _get_creds_for_label(all_creds, account_label)
if not creds_map:
return f'No account with label "{account_label}" found.'
multi = len(creds_map) > 1
parts = []
for label, creds in creds_map.items():
svc = _svc('gmail', 'v1', creds)
result = svc.users().messages().list(userId='me', q=query, maxResults=max_results).execute()
messages = result.get('messages', [])
if not messages:
parts.append(f'{"[" + label + "] " if multi else ""}No messages found.')
continue
lines = []
for m in messages:
msg = svc.users().messages().get(
userId='me', id=m['id'], format='metadata',
metadataHeaders=['Subject', 'From', 'Date']
).execute()
h = {hdr['name']: hdr['value'] for hdr in msg.get('payload', {}).get('headers', [])}
prefix = f'[{label}] ' if multi else ''
lines.append(f"{prefix}id={m['id']} | {h.get('Date', '')} | From: {h.get('From', '')} | {h.get('Subject', '(no subject)')}")
if result.get('nextPageToken'):
lines.append(f'{"[" + label + "] " if multi else ""}(more results available)')
parts.append('\n'.join(lines))
return '\n'.join(parts)
except Exception as e:
tb = traceback.format_exc()
print(f'[google] list_gmail_messages error: {e}\n{tb}')
print(f'[google] list_gmail_messages error: {e}\n{traceback.format_exc()}')
return f'Error listing Gmail messages: {e}'
@tool
def get_gmail_message(message_id: str, body_format: str = 'text') -> str:
def get_gmail_message(message_id: str, body_format: str = 'text', account_label: str = None) -> str:
"""Get the full content of a Gmail message by ID.
Args:
message_id: The Gmail message ID
body_format: 'text' (default), 'html', or 'raw'
account_label: Optional account label. Tries all accounts if omitted.
"""
try:
creds = _get_creds(_actor_id())
svc = _svc('gmail', 'v1', creds)
meta = svc.users().messages().get(
userId='me', id=message_id, format='metadata',
metadataHeaders=['Subject', 'From', 'To', 'Cc', 'Date']
).execute()
h = {hdr['name']: hdr['value'] for hdr in meta.get('payload', {}).get('headers', [])}
if body_format == 'raw':
import base64
raw = svc.users().messages().get(userId='me', id=message_id, format='raw').execute()
body = base64.urlsafe_b64decode(raw.get('raw', '') + '==').decode('utf-8', errors='replace')
else:
full = svc.users().messages().get(userId='me', id=message_id, format='full').execute()
body = _extract_body(full.get('payload', {}), prefer_html=(body_format == 'html'))
lines = [
f"From: {h.get('From', '')}",
f"To: {h.get('To', '')}",
f"Date: {h.get('Date', '')}",
f"Subject: {h.get('Subject', '')}",
'',
body,
]
if h.get('Cc'):
lines.insert(3, f"Cc: {h['Cc']}")
return '\n'.join(lines)
all_creds = _load_all_creds(_actor_id())
if not all_creds:
return 'No Google accounts connected.'
creds_map = _get_creds_for_label(all_creds, account_label)
if not creds_map:
return f'No account with label "{account_label}" found.'
multi = len(creds_map) > 1
for label, creds in creds_map.items():
try:
svc = _svc('gmail', 'v1', creds)
meta = svc.users().messages().get(
userId='me', id=message_id, format='metadata',
metadataHeaders=['Subject', 'From', 'To', 'Cc', 'Date']
).execute()
h = {hdr['name']: hdr['value'] for hdr in meta.get('payload', {}).get('headers', [])}
if body_format == 'raw':
import base64
raw = svc.users().messages().get(userId='me', id=message_id, format='raw').execute()
body = base64.urlsafe_b64decode(raw.get('raw', '') + '==').decode('utf-8', errors='replace')
else:
full = svc.users().messages().get(userId='me', id=message_id, format='full').execute()
body = _extract_body(full.get('payload', {}), prefer_html=(body_format == 'html'))
prefix = f'[{label}]\n' if multi else ''
lines = [
f"{prefix}From: {h.get('From', '')}",
f"To: {h.get('To', '')}",
f"Date: {h.get('Date', '')}",
f"Subject: {h.get('Subject', '')}",
'',
body,
]
if h.get('Cc'):
lines.insert(3, f"Cc: {h['Cc']}")
return '\n'.join(lines)
except Exception as e:
if multi:
print(f'[google] get_gmail_message label={label} not found: {e}')
continue
return f'Error fetching Gmail message: {e}'
return f'Message {message_id} not found in any connected account.'
except Exception as e:
return f'Error fetching Gmail message: {e}'
@@ -237,17 +320,14 @@ def _extract_body(payload: dict, prefer_html: bool = False) -> str:
return base64.urlsafe_b64decode(data + '==').decode('utf-8', errors='replace') if data else ''
parts = payload.get('parts', [])
# First pass: preferred type
for part in parts:
if part.get('mimeType') == target:
data = part.get('body', {}).get('data', '')
return base64.urlsafe_b64decode(data + '==').decode('utf-8', errors='replace') if data else ''
# Second pass: fallback type
for part in parts:
if part.get('mimeType') == fallback:
data = part.get('body', {}).get('data', '')
return base64.urlsafe_b64decode(data + '==').decode('utf-8', errors='replace') if data else ''
# Recurse into multipart
for part in parts:
text = _extract_body(part, prefer_html)
if text: