"""Long-term memory manager: windowed loading, compaction, and LTM retrieval.

A session keeps at most ``SESSION_WINDOW_SIZE`` conversation events. Older
events are summarised by Claude Haiku into a structured "LTM extraction",
stored as an event in the dedicated ``LTM_SESSION_ID`` session, and then
deleted from the live session.

Call order used by the agent entry point (``main`` in ``main.py``):

1. ``check_and_compact(actor_id, session_id)`` -- before the memory config /
   session manager is created, so old events are trimmed before history loads.
2. ``load_ltm(actor_id)`` -- a non-empty result is appended to the system
   prompt after a ``'\\n\\n---\\n\\n'`` separator.
3. ``check_window_and_flag(actor_id, session_id)`` -- in the ``finally`` block
   after ``session_manager.close()``, to schedule compaction for the next
   invocation.

All three entry points are best-effort: failures are logged, never raised.
"""

import json
import logging
import os
from datetime import datetime, timezone
from typing import Any

import boto3

from bedrock_agentcore.memory.client import MemoryClient

logger = logging.getLogger(__name__)

MEMORY_ID = 'agentclaw_AgentClawMemory-i7Csf776AH'
SESSION_WINDOW_SIZE = 100
USERS_TABLE_NAME = os.environ.get('USERS_TABLE_NAME', 'agent-claw-users')
LTM_SESSION_ID = 'ltm-extractions'
HAIKU_MODEL_ID = 'us.anthropic.claude-3-5-haiku-20241022-v1:0'

AWS_REGION = 'us-east-1'
# Upper bound passed to list_events when scanning a whole session.
MAX_SESSION_EVENTS = 10000
# Maximum number of conversation lines sent to the model in one extraction.
MAX_SUMMARY_LINES = 200
# Maximum number of stored extractions merged into the prompt block.
MAX_LTM_EXTRACTIONS = 50

# (heading, extraction key, max items) for each list section of the LTM block.
_LTM_SECTIONS = (
    ('**Facts:**', 'facts', 30),
    ('**Preferences:**', 'preferences', 15),
    ('**Upcoming dates/events:**', 'dates', 10),
)

# Sort key used for events that carry no usable timestamp (treated as oldest).
_EPOCH = datetime.min.replace(tzinfo=timezone.utc)

_memory_client: MemoryClient | None = None


def _get_memory_client() -> MemoryClient:
    """Return the process-wide ``MemoryClient``, creating it on first use."""
    global _memory_client
    if _memory_client is None:
        _memory_client = MemoryClient(region_name=AWS_REGION)
    return _memory_client


def _users_table() -> Any:
    """Return a handle to the DynamoDB table holding per-actor state."""
    return boto3.resource('dynamodb', region_name=AWS_REGION).Table(USERS_TABLE_NAME)


def _ltm_event_filter() -> list[dict]:
    """Build the metadata filter selecting events tagged ``type=ltm_extraction``."""
    return [{
        'left': {'metadataKey': 'type'},
        'operator': 'EQUALS_TO',
        'right': {'metadataValue': {'stringValue': 'ltm_extraction'}},
    }]


def _get_compaction_flag(actor_id: str) -> bool:
    """Return whether ``needs_compaction`` is set on the actor's user item."""
    resp = _users_table().get_item(Key={'actor_id': actor_id})
    return bool(resp.get('Item', {}).get('needs_compaction', False))


def _set_compaction_flag(actor_id: str, value: bool) -> None:
    """Set or clear ``needs_compaction`` on the actor's user item."""
    _users_table().update_item(
        Key={'actor_id': actor_id},
        UpdateExpression='SET needs_compaction = :v',
        ExpressionAttributeValues={':v': value},
    )


def _is_state_event(event: dict) -> bool:
    """Return True for session/agent state events (marked by ``stateType`` metadata)."""
    return bool((event.get('metadata') or {}).get('stateType'))


def _list_conversation_events(actor_id: str, session_id: str, *, include_payload: bool) -> list[dict]:
    """List the session's events with state/agent metadata events filtered out."""
    events = _get_memory_client().list_events(
        memory_id=MEMORY_ID,
        actor_id=actor_id,
        session_id=session_id,
        max_results=MAX_SESSION_EVENTS,
        include_payload=include_payload,
    )
    return [event for event in events if not _is_state_event(event)]


def _count_session_events(actor_id: str, session_id: str) -> int:
    """Count conversation events in the session (state events excluded)."""
    return len(_list_conversation_events(actor_id, session_id, include_payload=False))


def _get_all_session_events(actor_id: str, session_id: str) -> list[dict]:
    """Return all conversation events of the session, payload included."""
    return _list_conversation_events(actor_id, session_id, include_payload=True)


def _parse_timestamp(value: Any) -> datetime | None:
    """Coerce a ``datetime`` or ISO-8601 string into an aware ``datetime``.

    Returns ``None`` for anything else, including empty or malformed strings.
    Naive values are assumed to be UTC so every result is mutually comparable.
    """
    if isinstance(value, datetime):
        parsed = value
    elif isinstance(value, str):
        try:
            parsed = datetime.fromisoformat(value)
        except ValueError:
            return None
    else:
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def _event_time(event: dict) -> datetime | None:
    """Return the event's ``eventTimestamp`` as an aware ``datetime``, if any."""
    return _parse_timestamp(event.get('eventTimestamp'))


def _sort_chronologically(events: list[dict]) -> list[dict]:
    """Return events oldest-first without relying on the API's return order."""
    return sorted(events, key=lambda event: _event_time(event) or _EPOCH)


def _blob_lines(blob: Any) -> list[str]:
    """Render a blob payload holding ``[text, role]`` pairs as ``role: text`` lines.

    Blobs in any other shape are ignored (best effort).
    """
    if isinstance(blob, str):
        try:
            blob = json.loads(blob)
        except json.JSONDecodeError:
            return []
    if not isinstance(blob, list):
        return []
    return [
        f'{msg[1]}: {msg[0]}'
        for msg in blob
        if isinstance(msg, (list, tuple)) and len(msg) == 2
    ]


def _event_lines(event: dict) -> list[str]:
    """Return the ``role: text`` lines contained in one event's payload."""
    lines: list[str] = []
    for item in event.get('payload') or []:
        if 'conversational' in item:
            conv = item['conversational']
            role = conv.get('role', 'UNKNOWN')
            text = conv.get('content', {}).get('text', '')
            lines.append(f'{role}: {text}')
        elif 'blob' in item:
            lines.extend(_blob_lines(item['blob']))
    return lines


def _extract_text_from_events(events: list[dict]) -> str:
    """Extract conversation text from events for summarization.

    Only the last ``MAX_SUMMARY_LINES`` lines are kept to stay within the
    model's context; a warning is logged when older lines are dropped.
    """
    lines = [line for event in events for line in _event_lines(event)]
    if len(lines) > MAX_SUMMARY_LINES:
        logger.warning('[memory_manager] Truncating %d conversation lines to the last %d',
                       len(lines), MAX_SUMMARY_LINES)
    return '\n'.join(lines[-MAX_SUMMARY_LINES:])


def _parse_extraction_json(text: str) -> dict:
    """Parse the JSON object embedded in a model response.

    Slicing from the first ``{`` to the last ``}`` tolerates markdown code
    fences and any prose the model adds around the object.

    Raises:
        ValueError: if no JSON object is found or it cannot be decoded
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    start = text.find('{')
    end = text.rfind('}')
    if start == -1 or end < start:
        raise ValueError('model response does not contain a JSON object')
    return json.loads(text[start:end + 1])


def _call_claude_extraction(conversation_text: str) -> dict:
    """Call Claude Haiku to extract structured LTM from conversation text.

    Returns the parsed extraction with the keys requested in the prompt
    (``summary``, ``facts``, ``preferences``, ``dates``, ``topics``); the
    model's output is not validated beyond being a JSON object.
    """
    bedrock = boto3.client('bedrock-runtime', region_name=AWS_REGION)
    prompt = (
        'Extract structured long-term memory from this conversation. '
        'Return ONLY valid JSON with these keys:\n'
        '- "summary": 3-5 sentence narrative of what was discussed\n'
        '- "facts": array of factual statements worth remembering\n'
        '- "preferences": array of user preferences expressed\n'
        '- "dates": array of events/deadlines with date/time mentioned\n'
        '- "topics": array of topic keywords\n\n'
        'Conversation:\n' + conversation_text
    )
    resp = bedrock.converse(
        modelId=HAIKU_MODEL_ID,
        messages=[{'role': 'user', 'content': [{'text': prompt}]}],
        inferenceConfig={'maxTokens': 1024},
    )
    return _parse_extraction_json(resp['output']['message']['content'][0]['text'])


def _get_last_compaction_timestamp(actor_id: str) -> str | None:
    """Return the newest ``compacted_through`` watermark as an ISO-8601 string.

    The watermark is the timestamp of the newest session event already folded
    into an LTM extraction. It is read from the extraction's metadata rather
    than from the extraction event's own creation time, which only says when
    compaction ran. Returns ``None`` when no usable watermark exists.
    """
    events = _get_memory_client().list_events(
        memory_id=MEMORY_ID,
        actor_id=actor_id,
        session_id=LTM_SESSION_ID,
        event_metadata=_ltm_event_filter(),
        max_results=MAX_SESSION_EVENTS,
        include_payload=False,
    )
    watermarks = []
    for event in events:
        # NOTE(review): assumes metadata is returned in the shape it was
        # written ({'compacted_through': {'stringValue': ...}}) -- confirm.
        entry = (event.get('metadata') or {}).get('compacted_through') or {}
        parsed = _parse_timestamp(entry.get('stringValue'))
        if parsed is not None:
            watermarks.append(parsed)
    return max(watermarks).isoformat() if watermarks else None


def _split_by_watermark(events: list[dict], watermark: datetime | None) -> tuple[list[dict], list[dict]]:
    """Split events into (already summarised, still pending) by the watermark.

    Events without a timestamp are treated as pending so they are never
    deleted without having been offered to the summariser.
    """
    if watermark is None:
        return [], list(events)
    summarised: list[dict] = []
    pending: list[dict] = []
    for event in events:
        stamp = _event_time(event)
        if stamp is not None and stamp <= watermark:
            summarised.append(event)
        else:
            pending.append(event)
    return summarised, pending


def _take_summary_batch(pending: list[dict]) -> list[dict]:
    """Take the oldest pending events that fit into one extraction call.

    The batch always holds at least one event and is never cut between events
    that share a timestamp, because the watermark could not tell them apart.
    """
    batch: list[dict] = []
    used = 0
    for event in pending:
        cost = len(_event_lines(event))
        over_budget = bool(batch) and used + cost > MAX_SUMMARY_LINES
        if over_budget and _event_time(event) != _event_time(batch[-1]):
            break
        batch.append(event)
        used += cost
    return batch


def _store_extraction(actor_id: str, extraction: dict, compacted_through: datetime | None) -> None:
    """Persist one LTM extraction as an event in the LTM session."""
    watermark = compacted_through.isoformat() if compacted_through else ''
    _get_memory_client().create_event(
        memory_id=MEMORY_ID,
        actor_id=actor_id,
        session_id=LTM_SESSION_ID,
        messages=[(json.dumps(extraction), 'ASSISTANT')],
        event_timestamp=datetime.now(timezone.utc),
        metadata={
            'type': {'stringValue': 'ltm_extraction'},
            'actor_id': {'stringValue': actor_id},
            'compacted_through': {'stringValue': watermark},
        },
    )


def _delete_events(actor_id: str, session_id: str, events: list[dict]) -> int:
    """Delete events from the session; return how many deletions failed."""
    client = _get_memory_client()
    failures = 0
    for event in events:
        event_id = event.get('eventId')
        try:
            client.gmdp_client.delete_event(
                memoryId=MEMORY_ID,
                actorId=actor_id,
                sessionId=session_id,
                eventId=event_id,
            )
        except Exception as e:
            # Best effort: one failed delete must not abort the rest.
            failures += 1
            logger.warning('[memory_manager] Failed to delete event %s: %s', event_id, e)
    return failures


def check_and_compact(actor_id: str, session_id: str) -> None:
    """Run compaction if the flag is set. Call BEFORE creating session_manager.

    Events older than the newest ``SESSION_WINDOW_SIZE`` are handled as follows:

    * events at or before the stored watermark were summarised by an earlier
      run whose deletes did not complete; they are deleted without a new
      model call;
    * the oldest remaining events, up to ``MAX_SUMMARY_LINES`` lines, are
      summarised, stored as an LTM extraction and then deleted.

    At most one model call is made per invocation. The flag is cleared only
    when nothing is left to compact and every delete succeeded; otherwise the
    next invocation continues. Never raises: failures are logged and the flag
    is left set for a retry.
    """
    try:
        if not _get_compaction_flag(actor_id):
            return

        logger.info('[memory_manager] Compaction triggered for actor_id=%s', actor_id)

        events = _sort_chronologically(_get_all_session_events(actor_id, session_id))
        overflow = len(events) - SESSION_WINDOW_SIZE
        if overflow <= 0:
            _set_compaction_flag(actor_id, False)
            return

        # Everything before the retained window is a compaction candidate.
        candidates = events[:overflow]
        watermark = _parse_timestamp(_get_last_compaction_timestamp(actor_id))
        summarised, pending = _split_by_watermark(candidates, watermark)
        batch = _take_summary_batch(pending)

        to_delete = list(summarised)
        if batch:
            text = _extract_text_from_events(batch)
            if text.strip():
                extraction = _call_claude_extraction(text)
                # Store first, delete second: a crash in between leaves events
                # at or before the watermark, which the next run only deletes.
                _store_extraction(actor_id, extraction, _event_time(batch[-1]))
                to_delete.extend(batch)
            else:
                # Events that yield no text are left in the session untouched.
                logger.info('[memory_manager] No text to compact, clearing flag')

        failures = _delete_events(actor_id, session_id, to_delete)
        remaining = len(pending) - len(batch)
        if remaining == 0 and failures == 0:
            _set_compaction_flag(actor_id, False)

        logger.info('[memory_manager] Compacted %d events into LTM for actor_id=%s',
                    len(to_delete) - failures, actor_id)
        if remaining or failures:
            logger.info('[memory_manager] %d events still pending, %d deletes failed; flag left set',
                        remaining, failures)

    except Exception:
        # Don't clear flag -- retry next invocation.
        logger.exception('[memory_manager] Compaction failed for actor_id=%s', actor_id)


def check_window_and_flag(actor_id: str, session_id: str) -> None:
    """Set the compaction flag when the session exceeds the window.

    Intended to run after the session has been used, so the (slow) compaction
    happens at the start of the next invocation. Never raises.
    """
    try:
        count = _count_session_events(actor_id, session_id)
        if count > SESSION_WINDOW_SIZE:
            logger.info('[memory_manager] Session has %d events (> %d), setting compaction flag',
                        count, SESSION_WINDOW_SIZE)
            _set_compaction_flag(actor_id, True)
    except Exception:
        logger.exception('[memory_manager] Failed to check window for actor_id=%s', actor_id)


def _parse_extractions(events: list[dict]) -> list[dict]:
    """Decode stored extraction events into dicts, most recent first.

    Payloads that are not valid JSON objects are skipped, so one bad record
    cannot discard the whole memory block.
    """
    newest_first = sorted(events, key=lambda event: _event_time(event) or _EPOCH, reverse=True)
    extractions: list[dict] = []
    for event in newest_first:
        for item in event.get('payload') or []:
            if 'conversational' not in item:
                continue
            text = item['conversational'].get('content', {}).get('text', '')
            try:
                parsed = json.loads(text)
            except (json.JSONDecodeError, TypeError):
                continue
            if isinstance(parsed, dict):
                extractions.append(parsed)
    return extractions


def _collect_unique(extractions: list[dict], key: str, limit: int) -> list[str]:
    """Collect up to ``limit`` case-insensitively unique items stored under ``key``.

    Order follows ``extractions`` (most recent first). Non-list values are
    ignored, non-string items are stringified and blank items are dropped.
    """
    seen: set[str] = set()
    unique: list[str] = []
    for extraction in extractions:
        values = extraction.get(key)
        if not isinstance(values, list):
            continue
        for value in values:
            text = value if isinstance(value, str) else str(value)
            normalised = text.lower().strip()
            if not normalised or normalised in seen:
                continue
            seen.add(normalised)
            unique.append(text)
    return unique[:limit]


def _format_ltm_block(extractions: list[dict]) -> str:
    """Render extractions (most recent first) as a markdown prompt block.

    Returns an empty string when there is nothing to show besides the header.
    """
    parts = ['## Long-term memory\n']

    # Only the most recent summary is shown.
    summary = extractions[0].get('summary') if extractions else None
    if summary:
        parts.append(f'**Recent context:** {summary}\n')

    for heading, key, limit in _LTM_SECTIONS:
        items = _collect_unique(extractions, key, limit)
        if items:
            parts.append(heading)
            parts.extend(f'- {item}' for item in items)
            parts.append('')

    if len(parts) == 1:
        return ''
    return '\n'.join(parts)


def load_ltm(actor_id: str) -> str:
    """Load all LTM extractions for an actor and format as a system prompt block.

    Returns empty string on failure (non-fatal) or when nothing is stored.

    NOTE(review): the block is derived from model summaries of user
    conversation and is placed in the system prompt by the caller -- confirm
    that trust level is acceptable.
    """
    try:
        events = _get_memory_client().list_events(
            memory_id=MEMORY_ID,
            actor_id=actor_id,
            session_id=LTM_SESSION_ID,
            event_metadata=_ltm_event_filter(),
            max_results=MAX_LTM_EXTRACTIONS,
            include_payload=True,
        )
        if not events:
            return ''

        extractions = _parse_extractions(events)
        if not extractions:
            return ''

        result = _format_ltm_block(extractions)
        logger.info('[memory_manager] Loaded LTM block: %d chars from %d extractions',
                    len(result), len(extractions))
        return result

    except Exception:
        logger.exception('[memory_manager] LTM retrieval failed (non-fatal) for actor_id=%s', actor_id)
        return ''