Files
agent-claw/agentclaw/app/agent_claw_main/memory_manager.py
daniel 3a34e61479 feat: add windowed session history + LTM extraction/retrieval
- New memory_manager.py with:
  - check_and_compact: runs compaction on flagged sessions (extracts LTM via
    Claude Haiku, stores as AgentCore Memory event, deletes old events)
  - check_window_and_flag: sets DynamoDB flag when session > 100 events
  - load_ltm: retrieves LTM extractions and formats as system prompt block
- Wired into main.py:
  - Compaction runs before session_manager creation (trims old events)
  - LTM block injected into system prompt
  - Window check runs after session close
- SESSION_WINDOW_SIZE = 100 (named constant)
- Compaction is idempotent (uses event timestamps as cursor)
- LTM retrieval failure is non-fatal (logs and continues)
2026-05-13 11:57:50 -05:00

321 lines
12 KiB
Python

"""Long-term memory manager: windowed loading, compaction, and LTM retrieval."""
import json
import logging
import os
from datetime import datetime, timezone
from typing import Any
import boto3
from bedrock_agentcore.memory.client import MemoryClient
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# Memory resource id passed as `memory_id` to every MemoryClient call below.
MEMORY_ID = 'agentclaw_AgentClawMemory-i7Csf776AH'
# Number of most recent conversation events kept in a session: check_window_and_flag
# flags sessions above this size and check_and_compact compacts everything older.
SESSION_WINDOW_SIZE = 100
# DynamoDB table holding the per-actor `needs_compaction` flag; overridable via
# the USERS_TABLE_NAME environment variable.
USERS_TABLE_NAME = os.environ.get('USERS_TABLE_NAME', 'agent-claw-users')
# Session id under which LTM extraction events are stored and later listed.
LTM_SESSION_ID = 'ltm-extractions'
# Bedrock model id used for the extraction request in _call_claude_extraction.
HAIKU_MODEL_ID = 'us.anthropic.claude-3-5-haiku-20241022-v1:0'
# Process-wide MemoryClient, created lazily by _get_memory_client().
_memory_client: MemoryClient | None = None


def _get_memory_client() -> MemoryClient:
    """Return the shared MemoryClient, constructing it on first use."""
    global _memory_client
    if _memory_client is not None:
        return _memory_client
    _memory_client = MemoryClient(region_name='us-east-1')
    return _memory_client
def _get_compaction_flag(actor_id: str) -> bool:
    """Return the actor's `needs_compaction` attribute, or False when it is absent.

    Reads the item keyed by `actor_id` from the DynamoDB users table.
    """
    users_table = boto3.resource('dynamodb', region_name='us-east-1').Table(USERS_TABLE_NAME)
    response = users_table.get_item(Key={'actor_id': actor_id})
    # A missing item and a missing attribute both mean "no compaction requested".
    item = response.get('Item', {})
    return item.get('needs_compaction', False)
def _set_compaction_flag(actor_id: str, value: bool) -> None:
    """Write `value` to the actor's `needs_compaction` attribute in the users table."""
    dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
    update_kwargs = {
        'Key': {'actor_id': actor_id},
        'UpdateExpression': 'SET needs_compaction = :v',
        'ExpressionAttributeValues': {':v': value},
    }
    dynamodb.Table(USERS_TABLE_NAME).update_item(**update_kwargs)
def _count_session_events(actor_id: str, session_id: str) -> int:
    """Return the number of session events that carry no `stateType` metadata.

    Events with a truthy `stateType` metadata entry are treated as session/agent
    state records and are not counted.
    """
    listed = _get_memory_client().list_events(
        memory_id=MEMORY_ID,
        actor_id=actor_id,
        session_id=session_id,
        max_results=10000,
        include_payload=False,  # only metadata is inspected here
    )
    conversation_count = 0
    for entry in listed:
        if entry.get('metadata', {}).get('stateType'):
            continue
        conversation_count += 1
    return conversation_count
def _get_all_session_events(actor_id: str, session_id: str) -> list[dict]:
    """Return the session's events, payload included, minus `stateType` events.

    The relative order of the kept events is whatever `list_events` returned.
    """
    listed = _get_memory_client().list_events(
        memory_id=MEMORY_ID,
        actor_id=actor_id,
        session_id=session_id,
        max_results=10000,
        include_payload=True,
    )
    conversation_events = []
    for entry in listed:
        is_state_event = bool(entry.get('metadata', {}).get('stateType'))
        if not is_state_event:
            conversation_events.append(entry)
    return conversation_events
def _extract_text_from_events(events: list[dict], max_lines: int = 200) -> str:
    """Flatten event payloads into `ROLE: text` lines for summarization.

    Two payload item shapes are understood:
      * `{'conversational': {'role': ..., 'content': {'text': ...}}}`
      * `{'blob': ...}` where the blob (or its JSON-decoded form) is a list of
        two-element `[text, role]` pairs.

    Items with no text are skipped, so a batch of content-free events yields an
    empty string and callers can detect "nothing to summarize" with a plain
    emptiness check.

    Args:
        events: Event dicts; a missing or None `payload` is treated as empty.
        max_lines: Keep only this many trailing lines, which bounds the prompt
            size. Values <= 0 produce an empty string.

    Returns:
        The kept lines joined with newlines.
    """
    lines: list[str] = []
    for event in events:
        for item in event.get('payload') or []:
            if 'conversational' in item:
                conv = item['conversational'] or {}
                role = conv.get('role', 'UNKNOWN')
                text = (conv.get('content') or {}).get('text') or ''
                # Skip blank turns: a bare "ROLE: " line carries no information.
                if str(text).strip():
                    lines.append(f'{role}: {text}')
            elif 'blob' in item:
                raw = item['blob']
                try:
                    blob = json.loads(raw) if isinstance(raw, str) else raw
                except (json.JSONDecodeError, TypeError):
                    # Best effort by design: an undecodable blob is simply not
                    # part of the summary input.
                    continue
                if not isinstance(blob, list):
                    continue
                for msg in blob:
                    if not (isinstance(msg, (list, tuple)) and len(msg) == 2):
                        continue
                    body, speaker = msg
                    if body is None or (isinstance(body, str) and not body.strip()):
                        continue
                    lines.append(f'{speaker}: {body}')
    if max_lines <= 0:
        # Guard explicitly: lines[-0:] would return every line, not none.
        return ''
    return '\n'.join(lines[-max_lines:])
def _parse_ltm_json(text: str) -> dict:
    """Recover the JSON object from a model reply that may wrap it in extras.

    Handles a bare JSON object, a markdown-fenced block with or without a
    language tag (any case), and an object surrounded by prose.

    Raises:
        ValueError: if no JSON object can be decoded (json.JSONDecodeError is a
            ValueError subclass) or the decoded value is not a dict.
    """
    candidate = text.strip()
    if '```' in candidate:
        fenced = candidate.split('```')[1]
        # Drop a leading language tag line such as "json" / "JSON".
        first_line, newline, rest = fenced.partition('\n')
        if newline and first_line.strip().isalpha():
            fenced = rest
        candidate = fenced.strip()
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span, which covers prose before or
        # after the object and a tag left on the same line as the fence.
        start, end = candidate.find('{'), candidate.rfind('}')
        if start == -1 or end <= start:
            raise
        parsed = json.loads(candidate[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError(f'LTM extraction is not a JSON object: {type(parsed).__name__}')
    return parsed


def _call_claude_extraction(conversation_text: str) -> dict:
    """Ask the Haiku model for a structured LTM extraction of a conversation.

    Args:
        conversation_text: Conversation transcript appended to the prompt.

    Returns:
        The decoded JSON object; the prompt requests the keys `summary`,
        `facts`, `preferences`, `dates` and `topics`.

    Raises:
        ValueError: if the reply contains no decodable JSON object.
        Exception: whatever the Bedrock client raises for a failed request.
    """
    bedrock = boto3.client('bedrock-runtime', region_name='us-east-1')
    prompt = (
        'Extract structured long-term memory from this conversation. '
        'Return ONLY valid JSON with these keys:\n'
        '- "summary": 3-5 sentence narrative of what was discussed\n'
        '- "facts": array of factual statements worth remembering\n'
        '- "preferences": array of user preferences expressed\n'
        '- "dates": array of events/deadlines with date/time mentioned\n'
        '- "topics": array of topic keywords\n\n'
        'Conversation:\n' + conversation_text
    )
    resp = bedrock.converse(
        modelId=HAIKU_MODEL_ID,
        messages=[{'role': 'user', 'content': [{'text': prompt}]}],
        inferenceConfig={'maxTokens': 1024},
    )
    content_blocks = resp['output']['message']['content']
    # Use the first block that actually carries text rather than assuming index 0.
    reply = next((block['text'] for block in content_blocks if 'text' in block), '')
    return _parse_ltm_json(reply)
def _get_last_compaction_timestamp(actor_id: str) -> str | None:
    """Return the timestamp up to which the actor's session is already compacted.

    check_and_compact compares this value against the timestamp of the newest
    event it is about to compact, so the correct cursor is the
    `compacted_through` metadata written with each LTM extraction, not the time
    the extraction event itself was created. The creation time is later than
    every event kept in the window at that moment, so using it makes the next
    compaction of those events look already done.

    The result does not depend on the order in which `list_events` returns
    events: the largest cursor among the listed extractions is used. Cursors
    are compared as strings, matching how the caller compares them.

    Returns:
        The latest `compacted_through` value. If no listed extraction carries
        that metadata, the latest extraction `eventTimestamp` as a string (the
        previous behaviour). None when there is no extraction at all.
    """
    client = _get_memory_client()
    events = client.list_events(
        memory_id=MEMORY_ID,
        actor_id=actor_id,
        session_id=LTM_SESSION_ID,
        event_metadata=[{
            'left': {'metadataKey': 'type'},
            'operator': 'EQUALS_TO',
            'right': {'metadataValue': {'stringValue': 'ltm_extraction'}},
        }],
        # Fetch more than one event so the newest can be picked regardless of
        # the listing order; payloads are not needed for this.
        max_results=1000,
        include_payload=False,
    )
    if not events:
        return None

    # NOTE(review): assumes event metadata is returned when include_payload is
    # False, as _count_session_events already relies on. Confirm.
    cursors = []
    for event in events:
        through = (event.get('metadata') or {}).get('compacted_through')
        value = through.get('stringValue') if isinstance(through, dict) else None
        if value:
            cursors.append(str(value))
    if cursors:
        return max(cursors)

    # No extraction carries a cursor: fall back to the extraction time itself.
    timestamps = [str(event['eventTimestamp']) for event in events if event.get('eventTimestamp')]
    return max(timestamps) if timestamps else None
def _sort_events_chronologically(events: list[dict]) -> list[dict]:
    """Return events ordered oldest-first by `eventTimestamp`.

    If the input is newest-first it is reversed before the stable sort, so
    events sharing a timestamp keep their chronological relative order. When
    timestamps are missing or not comparable the input order is returned as is.
    """
    try:
        if events and events[0]['eventTimestamp'] > events[-1]['eventTimestamp']:
            events = list(reversed(events))
        return sorted(events, key=lambda event: event['eventTimestamp'])
    except (KeyError, TypeError):
        return events


def check_and_compact(actor_id: str, session_id: str) -> None:
    """Compact the session into long-term memory when the actor is flagged.

    Call BEFORE creating the session manager. When the `needs_compaction` flag
    is set and the session holds more than SESSION_WINDOW_SIZE conversation
    events, the events older than the window are summarized by the model,
    stored as one event in the LTM session, and deleted from the session.

    Errors while compacting are logged and swallowed, leaving the flag set so
    the next invocation retries. An error reading the flag itself is not
    caught here and propagates to the caller.
    """
    if not _get_compaction_flag(actor_id):
        return
    logger.info('[memory_manager] Compaction triggered for actor_id=%s', actor_id)
    try:
        events = _get_all_session_events(actor_id, session_id)
        total = len(events)
        if total <= SESSION_WINDOW_SIZE:
            _set_compaction_flag(actor_id, False)
            return

        # The slice below must drop the OLDEST events, so do not rely on the
        # order list_events happens to return; order explicitly.
        events = _sort_events_chronologically(events)

        # Events to compact: everything before the window.
        compact_count = total - SESSION_WINDOW_SIZE
        events_to_compact = events[:compact_count]

        # Idempotency: skip if the newest event in this batch is already
        # covered by a stored extraction.
        last_compacted = _get_last_compaction_timestamp(actor_id)
        newest_compacted_ts = str(events_to_compact[-1].get('eventTimestamp', ''))
        if last_compacted and last_compacted >= newest_compacted_ts:
            logger.info('[memory_manager] Already compacted up to %s, skipping', last_compacted)
            _set_compaction_flag(actor_id, False)
            return

        # NOTE(review): _extract_text_from_events caps its output at the last
        # 200 lines, yet every event in the batch is deleted below. A very
        # large batch therefore loses its earliest turns.
        text = _extract_text_from_events(events_to_compact)
        if not text.strip():
            logger.info('[memory_manager] No text to compact, clearing flag')
            _set_compaction_flag(actor_id, False)
            return
        extraction = _call_claude_extraction(text)

        # Store the extraction before deleting anything, so a failure here
        # leaves the session intact.
        client = _get_memory_client()
        client.create_event(
            memory_id=MEMORY_ID,
            actor_id=actor_id,
            session_id=LTM_SESSION_ID,
            messages=[(json.dumps(extraction), 'ASSISTANT')],
            event_timestamp=datetime.now(timezone.utc),
            metadata={
                'type': {'stringValue': 'ltm_extraction'},
                'actor_id': {'stringValue': actor_id},
                'compacted_through': {'stringValue': newest_compacted_ts},
            },
        )

        # Delete compacted events from the session. Each delete is best
        # effort, but failures are counted so the outcome is reported honestly.
        failed_deletes = 0
        for event in events_to_compact:
            try:
                client.gmdp_client.delete_event(
                    memoryId=MEMORY_ID,
                    actorId=actor_id,
                    sessionId=session_id,
                    eventId=event['eventId'],
                )
            except Exception as e:
                failed_deletes += 1
                logger.warning('[memory_manager] Failed to delete event %s: %s', event.get('eventId'), e)
        if failed_deletes:
            logger.warning(
                '[memory_manager] %d of %d compacted events could not be deleted for actor_id=%s',
                failed_deletes, compact_count, actor_id,
            )
        _set_compaction_flag(actor_id, False)
        logger.info('[memory_manager] Compacted %d events into LTM for actor_id=%s', compact_count, actor_id)
    except Exception as e:
        # Don't clear the flag, so the next invocation retries. Keep the
        # traceback to make the failure diagnosable.
        logger.exception('[memory_manager] Compaction failed: %s', e)
def check_window_and_flag(actor_id: str, session_id: str) -> None:
    """Flag the actor for compaction when the session has outgrown the window.

    Any failure is logged and swallowed; nothing is raised to the caller.
    """
    try:
        event_total = _count_session_events(actor_id, session_id)
        if event_total <= SESSION_WINDOW_SIZE:
            return
        logger.info(
            '[memory_manager] Session has %d events (> %d), setting compaction flag',
            event_total,
            SESSION_WINDOW_SIZE,
        )
        _set_compaction_flag(actor_id, True)
    except Exception as exc:
        logger.error('[memory_manager] Failed to check window: %s', exc)
def _ltm_items(extraction: dict, field: str) -> list:
    """Return `extraction[field]` when it is a list, otherwise an empty list."""
    value = extraction.get(field)
    return value if isinstance(value, list) else []


def _ltm_unique(extractions: list[dict], field: str) -> list:
    """Collect `field` entries across extractions, first occurrence wins.

    Duplicates are detected case-insensitively on the stripped string form.
    None and blank entries are dropped.
    """
    unique: list = []
    seen: set[str] = set()
    for extraction in extractions:
        for entry in _ltm_items(extraction, field):
            if entry is None:
                continue
            key = str(entry).lower().strip()
            if not key or key in seen:
                continue
            seen.add(key)
            unique.append(entry)
    return unique


def load_ltm(actor_id: str) -> str:
    """Load the actor's LTM extractions and format them as a system prompt block.

    The block holds the most recent summary, then de-duplicated facts (max 30)
    and preferences (max 15), then dates (max 10, not de-duplicated).
    Extractions that are not JSON objects, and fields that are not lists, are
    ignored, so one malformed extraction cannot blank out the whole block.

    Returns:
        The formatted block, or an empty string when there is nothing to show
        or retrieval fails (non-fatal: the error is logged).
    """
    try:
        client = _get_memory_client()
        events = client.list_events(
            memory_id=MEMORY_ID,
            actor_id=actor_id,
            session_id=LTM_SESSION_ID,
            event_metadata=[{
                'left': {'metadataKey': 'type'},
                'operator': 'EQUALS_TO',
                'right': {'metadataValue': {'stringValue': 'ltm_extraction'}},
            }],
            max_results=50,
            include_payload=True,
        )
        if not events:
            return ''

        # Most recent first, ordered explicitly rather than trusting the
        # listing order. Fall back to the previous "reverse the listing"
        # behaviour when timestamps are missing or not comparable.
        try:
            newest_first = sorted(events, key=lambda event: event['eventTimestamp'], reverse=True)
        except (KeyError, TypeError):
            newest_first = list(reversed(events))

        extractions: list[dict] = []
        for event in newest_first:
            for item in event.get('payload') or []:
                if 'conversational' not in item:
                    continue
                text = (item['conversational'].get('content') or {}).get('text', '')
                try:
                    parsed = json.loads(text)
                except (json.JSONDecodeError, TypeError):
                    # Best effort by design: skip payloads that are not JSON.
                    continue
                if isinstance(parsed, dict):
                    extractions.append(parsed)
        if not extractions:
            return ''

        # Build the LTM block: most recent summary first, then the lists.
        parts = ['## Long-term memory\n']
        if extractions[0].get('summary'):
            parts.append(f'**Recent context:** {extractions[0]["summary"]}\n')

        all_facts = _ltm_unique(extractions, 'facts')
        if all_facts:
            parts.append('**Facts:**')
            parts.extend(f'- {fact}' for fact in all_facts[:30])  # Cap at 30
            parts.append('')

        all_prefs = _ltm_unique(extractions, 'preferences')
        if all_prefs:
            parts.append('**Preferences:**')
            parts.extend(f'- {pref}' for pref in all_prefs[:15])
            parts.append('')

        # Dates: most recent extractions first, duplicates kept.
        all_dates = []
        for extraction in extractions:
            all_dates.extend(_ltm_items(extraction, 'dates'))
        if all_dates:
            parts.append('**Upcoming dates/events:**')
            parts.extend(f'- {date_entry}' for date_entry in all_dates[:10])
            parts.append('')

        result = '\n'.join(parts)
        logger.info('[memory_manager] Loaded LTM block: %d chars from %d extractions', len(result), len(extractions))
        return result
    except Exception as e:
        logger.exception('[memory_manager] LTM retrieval failed (non-fatal): %s', e)
        return ''