Multi-account Google support with user labels

This commit is contained in:
daniel
2026-05-09 11:21:37 -05:00
parent 01b258579b
commit 38d828ef74
4 changed files with 239 additions and 141 deletions

View File

@@ -94,8 +94,8 @@ def write_workspace_file(path: str, content: str) -> str:
@tool @tool
def connect_google_account() -> str: def connect_google_account(label: str = 'primary') -> str:
"""Generate a Google OAuth authorization URL for the current user to connect their Google account. """Connect a Google account with a custom label (e.g. 'work', 'personal'). Defaults to 'primary'.
Use this when the user wants to connect Google Workspace (Gmail, Calendar, Drive, etc.) Use this when the user wants to connect Google Workspace (Gmail, Calendar, Drive, etc.)
or when Google tools fail due to missing credentials.""" or when Google tools fail due to missing credentials."""
if not OAUTH_START_URL: if not OAUTH_START_URL:
@@ -103,8 +103,18 @@ def connect_google_account() -> str:
actor_id = _current_actor_id actor_id = _current_actor_id
if not actor_id: if not actor_id:
return 'Cannot determine actor_id for OAuth flow.' return 'Cannot determine actor_id for OAuth flow.'
url = f'{OAUTH_START_URL}?actor_id={actor_id}' url = f'{OAUTH_START_URL}?actor_id={actor_id}&label={label}'
return f'Please open this URL to connect your Google account:\n{url}\n\nAfter authorizing, Google Workspace tools (Gmail, Calendar, Drive) will be available.' return f'Please open this URL to connect your Google account as "{label}":\n{url}\n\nAfter authorizing, Google Workspace tools (Gmail, Calendar, Drive) will be available.'
@tool
def list_google_accounts() -> str:
"""List all connected Google accounts and their labels."""
accounts = _gws._current_google_accounts
if not accounts:
return 'No Google accounts connected. Use connect_google_account to add one.'
parts = [f'{label} ({email})' for label, email in accounts.items()]
return 'Connected Google accounts: ' + ', '.join(parts)
@tool @tool
@@ -145,11 +155,9 @@ def manage_service(action: str, service: str, config: dict | None = None) -> str
return 'service name is required.' return 'service name is required.'
if not config: if not config:
return 'config dict is required for enroll.' return 'config dict is required for enroll.'
# Validate known services
if service == 'home_assistant': if service == 'home_assistant':
if 'url' not in config or 'token' not in config: if 'url' not in config or 'token' not in config:
return 'home_assistant config requires "url" and "token" keys.' return 'home_assistant config requires "url" and "token" keys.'
# Update in-memory config immediately for this session
set_ha_config(config['url'], config['token']) set_ha_config(config['url'], config['token'])
table.update_item( table.update_item(
Key={'actor_id': actor_id}, Key={'actor_id': actor_id},
@@ -245,17 +253,21 @@ async def main(payload: dict, context):
ha_cfg = services.get('home_assistant', {}) ha_cfg = services.get('home_assistant', {})
set_ha_config(ha_cfg.get('url', ''), ha_cfg.get('token', '')) set_ha_config(ha_cfg.get('url', ''), ha_cfg.get('token', ''))
# Sync google_accounts to google_workspace module
google_accounts = user_profile.get('google_accounts', {})
_gws._current_google_accounts = google_accounts
# Build system prompt — base cached, user context injected per-invocation # Build system prompt — base cached, user context injected per-invocation
user_context = '' user_context = ''
if user_profile: if user_profile:
name = user_profile.get('display_name', '') name = user_profile.get('display_name', '')
username = user_profile.get('telegram_username', '') username = user_profile.get('telegram_username', '')
google_email = user_profile.get('google_email', '')
user_context = f'Name: {name}' user_context = f'Name: {name}'
if username: if username:
user_context += f'\nTelegram username: @{username}' user_context += f'\nTelegram username: @{username}'
if google_email: if google_accounts:
user_context += f'\nGoogle account: {google_email}' acct_list = ', '.join(f'{label} ({email})' for label, email in google_accounts.items())
user_context += f'\nGoogle accounts: {acct_list}'
else: else:
user_context += '\nGoogle account: not connected (use connect_google_account tool to connect)' user_context += '\nGoogle account: not connected (use connect_google_account tool to connect)'
enrolled = list(services.keys()) enrolled = list(services.keys())
@@ -282,7 +294,7 @@ async def main(payload: dict, context):
) )
base_tools = [web_search, web_fetch, read_workspace_file, write_workspace_file, base_tools = [web_search, web_fetch, read_workspace_file, write_workspace_file,
home_assistant, connect_google_account, home_assistant, connect_google_account, list_google_accounts,
manage_service, schedule_reminder, list_reminders, cancel_reminder, manage_service, schedule_reminder, list_reminders, cancel_reminder,
list_calendars, get_calendar_events, list_gmail_messages, get_gmail_message] list_calendars, get_calendar_events, list_gmail_messages, get_gmail_message]

View File

@@ -4,21 +4,26 @@ Mirrors workspace-mcp (gcalendar/gmail) logic using google-api-python-client dir
since workspace-mcp tool functions require FastMCP request context and cannot be called since workspace-mcp tool functions require FastMCP request context and cannot be called
outside an MCP server. outside an MCP server.
Credential secret: agent-claw/google-credentials/{actor_id.replace(':', '-')} Credential secrets: agent-claw/google-credentials/{safe_actor_id}/{label}
Contains: token, refresh_token, token_uri, client_id, client_secret, scopes Backward compat: agent-claw/google-credentials/{safe_actor_id} (treated as "primary")
""" """
import json import json
import time
import traceback import traceback
import boto3 import boto3
import httplib2
from strands import tool from strands import tool
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from google.auth.transport.requests import Request from google.auth.transport.requests import Request
from googleapiclient.discovery import build from googleapiclient.discovery import build
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
_HTTP_TIMEOUT = 15
_sm = None _sm = None
# Cache: actor_id -> (timestamp, {label: Credentials})
_creds_cache: dict[str, tuple[float, dict[str, Credentials]]] = {}
# Set per-invocation by main.py
_current_actor_id: str = ''
_current_google_accounts: dict = {} # {label: email} from DynamoDB
def _secrets(): def _secrets():
@@ -28,24 +33,24 @@ def _secrets():
return _sm return _sm
def _get_creds(actor_id: str) -> Credentials: def _actor_id():
from datetime import datetime return _current_actor_id
secret_name = 'agent-claw/google-credentials/' + actor_id.replace(':', '-')
print(f'[google] fetching creds actor={actor_id}')
data = json.loads(_secrets().get_secret_value(SecretId=secret_name)['SecretString']) def _load_creds_from_secret(secret_name: str) -> Credentials:
"""Load, optionally refresh, and return Credentials from a named secret."""
sm = _secrets()
data = json.loads(sm.get_secret_value(SecretId=secret_name)['SecretString'])
expiry_str = data.get('expiry') expiry_str = data.get('expiry')
expiry = None
if expiry_str: if expiry_str:
exp_aware = datetime.fromisoformat(expiry_str.replace('Z', '+00:00')) exp_aware = datetime.fromisoformat(expiry_str.replace('Z', '+00:00'))
expiry = exp_aware.replace(tzinfo=None) # google-auth uses naive UTC datetimes expiry = exp_aware.replace(tzinfo=None)
else:
expiry = None
stored_scopes = data.get('scopes', []) stored_scopes = data.get('scopes', [])
api_scopes = [s for s in stored_scopes if s.startswith('https://')] if stored_scopes else None api_scopes = [s for s in stored_scopes if s.startswith('https://')] if stored_scopes else None
# Fix stored scopes if they contain OIDC scopes
if stored_scopes and any(s in stored_scopes for s in ['openid', 'email', 'profile']): if stored_scopes and any(s in stored_scopes for s in ['openid', 'email', 'profile']):
data['scopes'] = api_scopes data['scopes'] = api_scopes
_secrets().put_secret_value(SecretId=secret_name, SecretString=json.dumps(data)) sm.put_secret_value(SecretId=secret_name, SecretString=json.dumps(data))
print('[google] fixed stored scopes: removed OIDC scopes')
creds = Credentials( creds = Credentials(
token=data.get('token'), token=data.get('token'),
refresh_token=data.get('refresh_token'), refresh_token=data.get('refresh_token'),
@@ -55,51 +60,96 @@ def _get_creds(actor_id: str) -> Credentials:
scopes=api_scopes, scopes=api_scopes,
expiry=expiry, expiry=expiry,
) )
print(f'[google] creds loaded, expired={creds.expired}')
if (creds.expired or not creds.valid) and creds.refresh_token: if (creds.expired or not creds.valid) and creds.refresh_token:
print('[google] refreshing token')
creds.refresh(Request()) creds.refresh(Request())
data['token'] = creds.token data['token'] = creds.token
if creds.expiry: if creds.expiry:
data['expiry'] = creds.expiry.isoformat() data['expiry'] = creds.expiry.isoformat()
_secrets().put_secret_value(SecretId=secret_name, SecretString=json.dumps(data)) sm.put_secret_value(SecretId=secret_name, SecretString=json.dumps(data))
print('[google] token refreshed and saved')
return creds return creds
def _actor_id(): def _load_all_creds(actor_id: str) -> dict[str, Credentials]:
# Read from module-level var set by main.py per invocation """Load all labeled credentials for actor_id, with 5-min TTL cache."""
# DO NOT use 'import main as _main' — it re-runs main.py including app.run() which hangs now = time.time()
return _current_actor_id if actor_id in _creds_cache:
ts, cached = _creds_cache[actor_id]
if now - ts < 300:
return cached
safe = actor_id.replace(':', '-').replace('/', '-')
prefix = f'agent-claw/google-credentials/{safe}/'
sm = _secrets()
result: dict[str, Credentials] = {}
# Set per-invocation by main.py before any tool call try:
_current_actor_id: str = '' paginator = sm.get_paginator('list_secrets')
for page in paginator.paginate(Filters=[{'Key': 'name', 'Values': [prefix]}]):
for secret in page.get('SecretList', []):
name = secret['Name']
label = name[len(prefix):]
if not label or '/' in label:
continue
try:
result[label] = _load_creds_from_secret(name)
print(f'[google] loaded creds actor={actor_id} label={label}')
except Exception as e:
print(f'[google] failed to load label={label}: {e}')
except Exception as e:
print(f'[google] list_secrets failed: {e}')
# Backward compat: flat secret path
if not result:
flat = f'agent-claw/google-credentials/{safe}'
try:
result['primary'] = _load_creds_from_secret(flat)
print(f'[google] loaded creds from flat path actor={actor_id}')
except Exception:
pass
_creds_cache[actor_id] = (now, result)
return result
def _svc(api: str, version: str, creds: Credentials): def _svc(api: str, version: str, creds: Credentials):
# Standard google-auth pattern: pass credentials= directly to build().
# google.oauth2.credentials.Credentials is natively supported by googleapiclient.
# (creds.authorize() is oauth2client only — not available here)
return build(api, version, credentials=creds, cache_discovery=False) return build(api, version, credentials=creds, cache_discovery=False)
def _get_creds_for_label(all_creds: dict[str, Credentials], label: str | None):
"""Return {label: creds} filtered by label, or all if label is None."""
if label:
if label not in all_creds:
return {}
return {label: all_creds[label]}
return all_creds
@tool @tool
def list_calendars() -> str: def list_calendars(account_label: str = None) -> str:
"""List all Google Calendars for the current user.""" """List all Google Calendars for the current user.
Args:
account_label: Optional account label (e.g. 'work', 'personal'). Lists all accounts if omitted.
"""
try: try:
creds = _get_creds(_actor_id()) all_creds = _load_all_creds(_actor_id())
result = _svc('calendar', 'v3', creds).calendarList().list().execute() if not all_creds:
items = result.get('items', []) return 'No Google accounts connected. Use connect_google_account to add one.'
if not items: creds_map = _get_creds_for_label(all_creds, account_label)
return 'No calendars found.' if not creds_map:
return '\n'.join( return f'No account with label "{account_label}" found.'
f'- "{c.get("summary", "")}"{" (Primary)" if c.get("primary") else ""} (ID: {c["id"]})' multi = len(creds_map) > 1
parts = []
for label, creds in creds_map.items():
items = _svc('calendar', 'v3', creds).calendarList().list().execute().get('items', [])
lines = [
f'{"[" + label + "] " if multi else ""}- "{c.get("summary", "")}"{" (Primary)" if c.get("primary") else ""} (ID: {c["id"]})'
for c in items for c in items
) ]
parts.append('\n'.join(lines) if lines else f'{"[" + label + "] " if multi else ""}No calendars found.')
return '\n'.join(parts)
except Exception as e: except Exception as e:
tb = traceback.format_exc() print(f'[google] list_calendars error: {e}\n{traceback.format_exc()}')
print(f'[google] list_calendars error: {e}\n{tb}')
return f'Error listing calendars: {e}' return f'Error listing calendars: {e}'
@@ -111,6 +161,7 @@ def get_calendar_events(
time_max: str = '', time_max: str = '',
max_results: int = 25, max_results: int = 25,
query: str = '', query: str = '',
account_label: str = None,
) -> str: ) -> str:
"""Get upcoming Google Calendar events. """Get upcoming Google Calendar events.
@@ -121,10 +172,16 @@ def get_calendar_events(
time_max: End of time range in RFC3339 format (optional) time_max: End of time range in RFC3339 format (optional)
max_results: Maximum events to return (default: 25) max_results: Maximum events to return (default: 25)
query: Keyword search within event fields (optional) query: Keyword search within event fields (optional)
account_label: Optional account label (e.g. 'work', 'personal'). Queries all accounts if omitted.
""" """
try: try:
creds = _get_creds(_actor_id()) all_creds = _load_all_creds(_actor_id())
svc = _svc('calendar', 'v3', creds) if not all_creds:
return 'No Google accounts connected.'
creds_map = _get_creds_for_label(all_creds, account_label)
if not creds_map:
return f'No account with label "{account_label}" found.'
multi = len(creds_map) > 1
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
params = { params = {
'calendarId': calendar_id, 'calendarId': calendar_id,
@@ -136,37 +193,50 @@ def get_calendar_events(
} }
if query: if query:
params['q'] = query params['q'] = query
events = svc.events().list(**params).execute().get('items', []) parts = []
for label, creds in creds_map.items():
events = _svc('calendar', 'v3', creds).events().list(**params).execute().get('items', [])
if not events: if not events:
return f'No events found in calendar "{calendar_id}".' parts.append(f'{"[" + label + "] " if multi else ""}No events found in calendar "{calendar_id}".')
continue
lines = [] lines = []
for e in events: for e in events:
start = e['start'].get('dateTime', e['start'].get('date', '')) start = e['start'].get('dateTime', e['start'].get('date', ''))
end = e['end'].get('dateTime', e['end'].get('date', '')) end = e['end'].get('dateTime', e['end'].get('date', ''))
eid = e.get('id', '') prefix = f'[{label}] ' if multi else ''
lines.append(f'- "{e.get("summary", "No Title")}" (Starts: {start}, Ends: {end}) ID: {eid}') lines.append(f'{prefix}- "{e.get("summary", "No Title")}" (Starts: {start}, Ends: {end}) ID: {e.get("id", "")}')
return f'Retrieved {len(events)} events from "{calendar_id}":\n' + '\n'.join(lines) parts.append(f'Retrieved {len(events)} events{" [" + label + "]" if multi else ""} from "{calendar_id}":\n' + '\n'.join(lines))
return '\n\n'.join(parts)
except Exception as e: except Exception as e:
tb = traceback.format_exc() print(f'[google] get_calendar_events error: {e}\n{traceback.format_exc()}')
print(f'[google] get_calendar_events error: {e}\n{tb}')
return f'Error fetching calendar events: {e}' return f'Error fetching calendar events: {e}'
@tool @tool
def list_gmail_messages(max_results: int = 10, query: str = 'in:inbox') -> str: def list_gmail_messages(max_results: int = 10, query: str = 'in:inbox', account_label: str = None) -> str:
"""List Gmail messages. """List Gmail messages.
Args: Args:
max_results: Maximum number of messages to return (default: 10) max_results: Maximum number of messages to return (default: 10)
query: Gmail search query (default: 'in:inbox') query: Gmail search query (default: 'in:inbox')
account_label: Optional account label (e.g. 'work', 'personal'). Lists all accounts if omitted.
""" """
try: try:
creds = _get_creds(_actor_id()) all_creds = _load_all_creds(_actor_id())
if not all_creds:
return 'No Google accounts connected.'
creds_map = _get_creds_for_label(all_creds, account_label)
if not creds_map:
return f'No account with label "{account_label}" found.'
multi = len(creds_map) > 1
parts = []
for label, creds in creds_map.items():
svc = _svc('gmail', 'v1', creds) svc = _svc('gmail', 'v1', creds)
result = svc.users().messages().list(userId='me', q=query, maxResults=max_results).execute() result = svc.users().messages().list(userId='me', q=query, maxResults=max_results).execute()
messages = result.get('messages', []) messages = result.get('messages', [])
if not messages: if not messages:
return 'No messages found.' parts.append(f'{"[" + label + "] " if multi else ""}No messages found.')
continue
lines = [] lines = []
for m in messages: for m in messages:
msg = svc.users().messages().get( msg = svc.users().messages().get(
@@ -174,35 +244,42 @@ def list_gmail_messages(max_results: int = 10, query: str = 'in:inbox') -> str:
metadataHeaders=['Subject', 'From', 'Date'] metadataHeaders=['Subject', 'From', 'Date']
).execute() ).execute()
h = {hdr['name']: hdr['value'] for hdr in msg.get('payload', {}).get('headers', [])} h = {hdr['name']: hdr['value'] for hdr in msg.get('payload', {}).get('headers', [])}
lines.append(f"id={m['id']} | {h.get('Date', '')} | From: {h.get('From', '')} | {h.get('Subject', '(no subject)')}") prefix = f'[{label}] ' if multi else ''
next_token = result.get('nextPageToken') lines.append(f"{prefix}id={m['id']} | {h.get('Date', '')} | From: {h.get('From', '')} | {h.get('Subject', '(no subject)')}")
out = '\n'.join(lines) if result.get('nextPageToken'):
if next_token: lines.append(f'{"[" + label + "] " if multi else ""}(more results available)')
out += f'\n(more results available)' parts.append('\n'.join(lines))
return out return '\n'.join(parts)
except Exception as e: except Exception as e:
tb = traceback.format_exc() print(f'[google] list_gmail_messages error: {e}\n{traceback.format_exc()}')
print(f'[google] list_gmail_messages error: {e}\n{tb}')
return f'Error listing Gmail messages: {e}' return f'Error listing Gmail messages: {e}'
@tool @tool
def get_gmail_message(message_id: str, body_format: str = 'text') -> str: def get_gmail_message(message_id: str, body_format: str = 'text', account_label: str = None) -> str:
"""Get the full content of a Gmail message by ID. """Get the full content of a Gmail message by ID.
Args: Args:
message_id: The Gmail message ID message_id: The Gmail message ID
body_format: 'text' (default), 'html', or 'raw' body_format: 'text' (default), 'html', or 'raw'
account_label: Optional account label. Tries all accounts if omitted.
""" """
try: try:
creds = _get_creds(_actor_id()) all_creds = _load_all_creds(_actor_id())
if not all_creds:
return 'No Google accounts connected.'
creds_map = _get_creds_for_label(all_creds, account_label)
if not creds_map:
return f'No account with label "{account_label}" found.'
multi = len(creds_map) > 1
for label, creds in creds_map.items():
try:
svc = _svc('gmail', 'v1', creds) svc = _svc('gmail', 'v1', creds)
meta = svc.users().messages().get( meta = svc.users().messages().get(
userId='me', id=message_id, format='metadata', userId='me', id=message_id, format='metadata',
metadataHeaders=['Subject', 'From', 'To', 'Cc', 'Date'] metadataHeaders=['Subject', 'From', 'To', 'Cc', 'Date']
).execute() ).execute()
h = {hdr['name']: hdr['value'] for hdr in meta.get('payload', {}).get('headers', [])} h = {hdr['name']: hdr['value'] for hdr in meta.get('payload', {}).get('headers', [])}
if body_format == 'raw': if body_format == 'raw':
import base64 import base64
raw = svc.users().messages().get(userId='me', id=message_id, format='raw').execute() raw = svc.users().messages().get(userId='me', id=message_id, format='raw').execute()
@@ -210,9 +287,9 @@ def get_gmail_message(message_id: str, body_format: str = 'text') -> str:
else: else:
full = svc.users().messages().get(userId='me', id=message_id, format='full').execute() full = svc.users().messages().get(userId='me', id=message_id, format='full').execute()
body = _extract_body(full.get('payload', {}), prefer_html=(body_format == 'html')) body = _extract_body(full.get('payload', {}), prefer_html=(body_format == 'html'))
prefix = f'[{label}]\n' if multi else ''
lines = [ lines = [
f"From: {h.get('From', '')}", f"{prefix}From: {h.get('From', '')}",
f"To: {h.get('To', '')}", f"To: {h.get('To', '')}",
f"Date: {h.get('Date', '')}", f"Date: {h.get('Date', '')}",
f"Subject: {h.get('Subject', '')}", f"Subject: {h.get('Subject', '')}",
@@ -222,6 +299,12 @@ def get_gmail_message(message_id: str, body_format: str = 'text') -> str:
if h.get('Cc'): if h.get('Cc'):
lines.insert(3, f"Cc: {h['Cc']}") lines.insert(3, f"Cc: {h['Cc']}")
return '\n'.join(lines) return '\n'.join(lines)
except Exception as e:
if multi:
print(f'[google] get_gmail_message label={label} not found: {e}')
continue
return f'Error fetching Gmail message: {e}'
return f'Message {message_id} not found in any connected account.'
except Exception as e: except Exception as e:
return f'Error fetching Gmail message: {e}' return f'Error fetching Gmail message: {e}'
@@ -237,17 +320,14 @@ def _extract_body(payload: dict, prefer_html: bool = False) -> str:
return base64.urlsafe_b64decode(data + '==').decode('utf-8', errors='replace') if data else '' return base64.urlsafe_b64decode(data + '==').decode('utf-8', errors='replace') if data else ''
parts = payload.get('parts', []) parts = payload.get('parts', [])
# First pass: preferred type
for part in parts: for part in parts:
if part.get('mimeType') == target: if part.get('mimeType') == target:
data = part.get('body', {}).get('data', '') data = part.get('body', {}).get('data', '')
return base64.urlsafe_b64decode(data + '==').decode('utf-8', errors='replace') if data else '' return base64.urlsafe_b64decode(data + '==').decode('utf-8', errors='replace') if data else ''
# Second pass: fallback type
for part in parts: for part in parts:
if part.get('mimeType') == fallback: if part.get('mimeType') == fallback:
data = part.get('body', {}).get('data', '') data = part.get('body', {}).get('data', '')
return base64.urlsafe_b64decode(data + '==').decode('utf-8', errors='replace') if data else '' return base64.urlsafe_b64decode(data + '==').decode('utf-8', errors='replace') if data else ''
# Recurse into multipart
for part in parts: for part in parts:
text = _extract_body(part, prefer_html) text = _extract_body(part, prefer_html)
if text: if text:

View File

@@ -185,7 +185,7 @@ def handler(event, context):
'user_profile': { 'user_profile': {
'display_name': user_profile.get('display_name', actor_id), 'display_name': user_profile.get('display_name', actor_id),
'telegram_username': user_profile.get('telegram_username', ''), 'telegram_username': user_profile.get('telegram_username', ''),
'google_email': user_profile.get('google_email', ''), 'google_accounts': user_profile.get('google_accounts', {'primary': user_profile['google_email']} if user_profile.get('google_email') else {}),
'allowed': user_profile.get('allowed', True), 'allowed': user_profile.get('allowed', True),
'services': user_profile.get('enrolled_services', user_profile.get('services', {})), 'services': user_profile.get('enrolled_services', user_profile.get('services', {})),
}, },

View File

@@ -2,12 +2,10 @@
Google OAuth handler Lambda. Google OAuth handler Lambda.
Routes: Routes:
GET /oauth/start?actor_id=telegram:123 → redirect to Google OAuth consent GET /oauth/start?actor_id=telegram:123&label=work → redirect to Google OAuth consent
GET /oauth/callback?code=...&state=... → exchange code, store tokens, update DynamoDB GET /oauth/callback?code=...&state=... → exchange code, store tokens, update DynamoDB
""" """
import base64 import base64
import hashlib
import hmac
import json import json
import os import os
import time import time
@@ -52,9 +50,9 @@ def get_oauth_client() -> tuple[str, str]:
return secret['client_id'], secret['client_secret'] return secret['client_id'], secret['client_secret']
def actor_id_to_secret_name(actor_id: str) -> str: def actor_id_to_secret_name(actor_id: str, label: str = 'primary') -> str:
safe = actor_id.replace(':', '-').replace('/', '-') safe = actor_id.replace(':', '-').replace('/', '-')
return f'agent-claw/google-credentials/{safe}' return f'agent-claw/google-credentials/{safe}/{label}'
def _redirect(url: str) -> dict: def _redirect(url: str) -> dict:
@@ -82,11 +80,14 @@ def handle_start(params: dict) -> dict:
if not actor_id: if not actor_id:
return _html('<h1>Missing actor_id</h1>', 400) return _html('<h1>Missing actor_id</h1>', 400)
label = params.get('label', 'primary')
client_id, _ = get_oauth_client() client_id, _ = get_oauth_client()
redirect_uri = os.environ['OAUTH_REDIRECT_URI'] redirect_uri = os.environ['OAUTH_REDIRECT_URI']
# Encode actor_id in state (base64 to keep URL-safe) # Encode actor_id + label in state (JSON → base64)
state = base64.urlsafe_b64encode(actor_id.encode()).decode().rstrip('=') state_data = json.dumps({'a': actor_id, 'l': label})
state = base64.urlsafe_b64encode(state_data.encode()).decode().rstrip('=')
auth_url = ( auth_url = (
'https://accounts.google.com/o/oauth2/v2/auth?' 'https://accounts.google.com/o/oauth2/v2/auth?'
@@ -113,10 +114,18 @@ def handle_callback(params: dict) -> dict:
if not code or not state: if not code or not state:
return _html('<h1>Missing code or state</h1>', 400) return _html('<h1>Missing code or state</h1>', 400)
# Decode actor_id from state # Decode actor_id + label from state
try:
padding = 4 - len(state) % 4
state_data = json.loads(base64.urlsafe_b64decode(state + '=' * padding).decode())
actor_id = state_data['a']
label = state_data.get('l', 'primary')
except Exception:
# Backward compat: old state was just base64(actor_id)
try: try:
padding = 4 - len(state) % 4 padding = 4 - len(state) % 4
actor_id = base64.urlsafe_b64decode(state + '=' * padding).decode() actor_id = base64.urlsafe_b64decode(state + '=' * padding).decode()
label = 'primary'
except Exception: except Exception:
return _html('<h1>Invalid state</h1>', 400) return _html('<h1>Invalid state</h1>', 400)
@@ -155,7 +164,6 @@ def handle_callback(params: dict) -> dict:
pass pass
if not user_email: if not user_email:
# Fallback: call userinfo endpoint
try: try:
access_token = tokens.get('access_token', '') access_token = tokens.get('access_token', '')
req2 = urllib.request.Request( req2 = urllib.request.Request(
@@ -184,23 +192,24 @@ def handle_callback(params: dict) -> dict:
time.gmtime(time.time() + int(tokens['expires_in'])) time.gmtime(time.time() + int(tokens['expires_in']))
) )
# Store in Secrets Manager # Store in Secrets Manager at labeled path
secret_name = actor_id_to_secret_name(actor_id) secret_name = actor_id_to_secret_name(actor_id, label)
sm = get_sm() sm = get_sm()
try: try:
sm.create_secret(Name=secret_name, SecretString=json.dumps(creds)) sm.create_secret(Name=secret_name, SecretString=json.dumps(creds))
except sm.exceptions.ResourceExistsException: except sm.exceptions.ResourceExistsException:
sm.put_secret_value(SecretId=secret_name, SecretString=json.dumps(creds)) sm.put_secret_value(SecretId=secret_name, SecretString=json.dumps(creds))
print(f'[oauth] Stored credentials for actor={actor_id} email={user_email}') print(f'[oauth] Stored credentials for actor={actor_id} label={label} email={user_email}')
# Update DynamoDB users table with google_email # Update DynamoDB: merge into google_accounts map
table_name = os.environ.get('USERS_TABLE_NAME', '') table_name = os.environ.get('USERS_TABLE_NAME', '')
if table_name and actor_id: if table_name and actor_id:
try: try:
get_ddb().Table(table_name).update_item( get_ddb().Table(table_name).update_item(
Key={'actor_id': actor_id}, Key={'actor_id': actor_id},
UpdateExpression='SET google_email = :e', UpdateExpression='SET google_accounts = if_not_exists(google_accounts, :empty), google_accounts.#label = :email',
ExpressionAttributeValues={':e': user_email}, ExpressionAttributeNames={'#label': label},
ExpressionAttributeValues={':email': user_email, ':empty': {}},
) )
except Exception as e: except Exception as e:
print(f'[oauth] DynamoDB update failed: {e}') print(f'[oauth] DynamoDB update failed: {e}')
@@ -211,10 +220,7 @@ def handle_callback(params: dict) -> dict:
if bot_token_arn and actor_id.startswith('telegram:'): if bot_token_arn and actor_id.startswith('telegram:'):
chat_id = actor_id.split(':', 1)[1] chat_id = actor_id.split(':', 1)[1]
bot_token = get_sm().get_secret_value(SecretId=bot_token_arn)['SecretString'] bot_token = get_sm().get_secret_value(SecretId=bot_token_arn)['SecretString']
tg_text = ( tg_text = f'✅ Connected {user_email} as "{label}"'
f'✅ Google account connected!\n\n'
f'{user_email} is now linked. You can now ask me about your Gmail, Calendar, and Drive.'
)
tg_payload = json.dumps({'chat_id': chat_id, 'text': tg_text}).encode() tg_payload = json.dumps({'chat_id': chat_id, 'text': tg_text}).encode()
tg_req = urllib.request.Request( tg_req = urllib.request.Request(
f'https://api.telegram.org/bot{bot_token}/sendMessage', f'https://api.telegram.org/bot{bot_token}/sendMessage',
@@ -227,6 +233,6 @@ def handle_callback(params: dict) -> dict:
return _html( return _html(
f'<h1>✅ Google account connected!</h1>' f'<h1>✅ Google account connected!</h1>'
f'<p>Connected <b>{user_email}</b> to your agent account.</p>' f'<p>Connected <b>{user_email}</b> as "<b>{label}</b>".</p>'
f'<p>You can close this window and return to Telegram.</p>' f'<p>You can close this window and return to Telegram.</p>'
) )