feat: add windowed session history + LTM extraction/retrieval
- New memory_manager.py with:
- check_and_compact: runs compaction on flagged sessions (extracts LTM via
Claude Haiku, stores as AgentCore Memory event, deletes old events)
- check_window_and_flag: sets DynamoDB flag when session > 100 events
- load_ltm: retrieves LTM extractions and formats as system prompt block
- Wired into main.py:
- Compaction runs before session_manager creation (trims old events)
- LTM block injected into system prompt
- Window check runs after session close
- SESSION_WINDOW_SIZE = 100 (named constant)
- Compaction is idempotent (uses event timestamps as cursor)
- LTM retrieval failure is non-fatal (logs and continues)
This commit is contained in:
320
agentclaw/app/agent_claw_main/memory_manager.py
Normal file
320
agentclaw/app/agent_claw_main/memory_manager.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""Long-term memory manager: windowed loading, compaction, and LTM retrieval."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
|
||||
from bedrock_agentcore.memory.client import MemoryClient
|
||||
|
||||
logger = logging.getLogger(__name__)

# Identifier of the AgentCore Memory resource; passed as memory_id (memoryId for
# the low-level delete call) on every memory-client call in this module.
MEMORY_ID = 'agentclaw_AgentClawMemory-i7Csf776AH'
# Number of most-recent conversation events kept in a session. Sessions with
# more than this many non-state events are flagged, and compaction trims the
# session back down to this size.
SESSION_WINDOW_SIZE = 100
# DynamoDB table holding the per-actor needs_compaction flag (key: actor_id).
# Overridable through the USERS_TABLE_NAME environment variable.
USERS_TABLE_NAME = os.environ.get('USERS_TABLE_NAME', 'agent-claw-users')
# Session id under which LTM extraction events are stored and looked up.
LTM_SESSION_ID = 'ltm-extractions'
# Bedrock model id used for the extraction call.
HAIKU_MODEL_ID = 'us.anthropic.claude-3-5-haiku-20241022-v1:0'

# Cached MemoryClient; created on first use by _get_memory_client().
_memory_client: MemoryClient | None = None
|
||||
|
||||
|
||||
def _get_memory_client() -> MemoryClient:
    """Return the module-wide MemoryClient, constructing it on first call."""
    global _memory_client
    if _memory_client is not None:
        return _memory_client
    # First use: build the client once and cache it for later calls.
    _memory_client = MemoryClient(region_name='us-east-1')
    return _memory_client
|
||||
|
||||
|
||||
def _get_compaction_flag(actor_id: str) -> bool:
    """Read the actor's ``needs_compaction`` attribute from the users table.

    Returns False when the actor has no item or the attribute is absent.
    """
    users_table = boto3.resource('dynamodb', region_name='us-east-1').Table(USERS_TABLE_NAME)
    response = users_table.get_item(Key={'actor_id': actor_id})
    item = response.get('Item', {})
    return item.get('needs_compaction', False)
|
||||
|
||||
|
||||
def _set_compaction_flag(actor_id: str, value: bool) -> None:
    """Write ``needs_compaction = value`` on the actor's item in the users table."""
    users_table = boto3.resource('dynamodb', region_name='us-east-1').Table(USERS_TABLE_NAME)
    item_key = {'actor_id': actor_id}
    new_values = {':v': value}
    users_table.update_item(
        Key=item_key,
        UpdateExpression='SET needs_compaction = :v',
        ExpressionAttributeValues=new_values,
    )
|
||||
|
||||
|
||||
def _count_session_events(actor_id: str, session_id: str) -> int:
    """Return how many events in the session carry no ``stateType`` metadata.

    Payloads are not requested; only each event's metadata is inspected.
    """
    listed = _get_memory_client().list_events(
        memory_id=MEMORY_ID,
        actor_id=actor_id,
        session_id=session_id,
        max_results=10000,
        include_payload=False,
    )
    conversation_total = 0
    for record in listed:
        # Events tagged with stateType are session/agent state, not conversation.
        if record.get('metadata', {}).get('stateType'):
            continue
        conversation_total += 1
    return conversation_total
|
||||
|
||||
|
||||
def _get_all_session_events(actor_id: str, session_id: str) -> list[dict]:
    """Fetch the session's events with payloads, dropping ``stateType`` events."""
    listed = _get_memory_client().list_events(
        memory_id=MEMORY_ID,
        actor_id=actor_id,
        session_id=session_id,
        max_results=10000,
        include_payload=True,
    )
    conversation_events = []
    for record in listed:
        metadata = record.get('metadata', {})
        # Keep only events that are not session/agent state snapshots.
        if not metadata.get('stateType'):
            conversation_events.append(record)
    return conversation_events
|
||||
|
||||
|
||||
def _extract_text_from_events(events: list[dict]) -> str:
|
||||
"""Extract conversation text from events for summarization."""
|
||||
lines = []
|
||||
for event in events:
|
||||
for item in event.get('payload', []):
|
||||
if 'conversational' in item:
|
||||
conv = item['conversational']
|
||||
role = conv.get('role', 'UNKNOWN')
|
||||
text = conv.get('content', {}).get('text', '')
|
||||
lines.append(f'{role}: {text}')
|
||||
elif 'blob' in item:
|
||||
try:
|
||||
blob = json.loads(item['blob']) if isinstance(item['blob'], str) else item['blob']
|
||||
if isinstance(blob, list) and blob:
|
||||
for msg in blob:
|
||||
if isinstance(msg, (list, tuple)) and len(msg) == 2:
|
||||
lines.append(f'{msg[1]}: {msg[0]}')
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return '\n'.join(lines[-200:]) # Cap at last 200 lines to stay within context
|
||||
|
||||
|
||||
def _parse_extraction_json(raw: str) -> dict:
    """Turn the model's reply into a dict, tolerating fences and surrounding prose."""
    candidate = raw.strip()
    if '```' in candidate:
        # Use the contents of the first fenced block and drop an optional
        # language tag such as "json" / "JSON".
        candidate = candidate.split('```')[1]
        if candidate[:4].lower() == 'json':
            candidate = candidate[4:]
        candidate = candidate.strip()
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span of the whole reply, which covers
        # replies that wrap the JSON object in explanatory text.
        start = raw.find('{')
        end = raw.rfind('}')
        if start == -1 or end <= start:
            raise
        parsed = json.loads(raw[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError(
            f'LTM extraction must be a JSON object, got {type(parsed).__name__}'
        )
    return parsed


def _call_claude_extraction(conversation_text: str) -> dict:
    """Ask the Haiku model for a structured long-term-memory extraction.

    Sends one user message through the Bedrock ``converse`` call asking for a
    JSON object with the keys ``summary``, ``facts``, ``preferences``,
    ``dates`` and ``topics``, then parses the reply.

    Args:
        conversation_text: Transcript to summarize; appended to the prompt.

    Returns:
        The parsed JSON object. The presence of individual keys is not checked.

    Raises:
        json.JSONDecodeError: The reply contains no parseable JSON.
        ValueError: The reply parses to something other than a JSON object.
        Exception: Whatever the Bedrock client raises for a failed call.
    """
    bedrock = boto3.client('bedrock-runtime', region_name='us-east-1')
    prompt = (
        'Extract structured long-term memory from this conversation. '
        'Return ONLY valid JSON with these keys:\n'
        '- "summary": 3-5 sentence narrative of what was discussed\n'
        '- "facts": array of factual statements worth remembering\n'
        '- "preferences": array of user preferences expressed\n'
        '- "dates": array of events/deadlines with date/time mentioned\n'
        '- "topics": array of topic keywords\n\n'
        'Conversation:\n' + conversation_text
    )
    resp = bedrock.converse(
        modelId=HAIKU_MODEL_ID,
        messages=[{'role': 'user', 'content': [{'text': prompt}]}],
        inferenceConfig={'maxTokens': 1024},
    )
    # Concatenate every text block instead of assuming the first block is text.
    content_blocks = resp['output']['message']['content']
    text = ''.join(block.get('text', '') for block in content_blocks)
    return _parse_extraction_json(text)
|
||||
|
||||
|
||||
def _get_last_compaction_timestamp(actor_id: str) -> str | None:
    """Return the conversation timestamp that earlier compactions already cover.

    Every LTM extraction event is written with a ``compacted_through`` metadata
    string: the timestamp of the newest session event folded into that
    extraction. That value is the idempotency cursor. The extraction event's own
    ``eventTimestamp`` is the wall-clock time of the compaction run, which is
    later than every session event retained at that moment; using it as the
    cursor makes subsequent compactions look already done.

    All matching events are scanned and the greatest cursor is returned, so the
    result does not depend on the order in which the service lists events.

    Args:
        actor_id: Actor whose LTM extraction events are inspected.

    Returns:
        The greatest ``compacted_through`` string, compared as strings (the
        same way the caller compares it), or ``None`` when no extraction event
        carries one.
    """
    client = _get_memory_client()
    events = client.list_events(
        memory_id=MEMORY_ID,
        actor_id=actor_id,
        session_id=LTM_SESSION_ID,
        event_metadata=[{
            'left': {'metadataKey': 'type'},
            'operator': 'EQUALS_TO',
            'right': {'metadataValue': {'stringValue': 'ltm_extraction'}},
        }],
        max_results=10000,
        # NOTE(review): assumes metadata is still returned when payloads are
        # excluded, as the session event counter in this module also does.
        include_payload=False,
    )
    cursors = []
    for event in events or []:
        marker = (event.get('metadata') or {}).get('compacted_through')
        value = marker.get('stringValue') if isinstance(marker, dict) else None
        if value:
            cursors.append(value)
    return max(cursors) if cursors else None
|
||||
|
||||
|
||||
def _event_epoch_seconds(event: dict) -> float:
    """Sort key: the event's ``eventTimestamp`` as epoch seconds (0.0 if unusable)."""
    ts = event.get('eventTimestamp')
    if isinstance(ts, datetime):
        return ts.timestamp()
    try:
        return float(ts)
    except (TypeError, ValueError):
        return 0.0


def check_and_compact(actor_id: str, session_id: str) -> None:
    """Run compaction if the flag is set. Call BEFORE creating session_manager.

    When the actor is flagged and the session holds more than
    ``SESSION_WINDOW_SIZE`` conversation events, the events older than the
    window are summarized by the extraction model, the result is stored as an
    ``ltm_extraction`` event under ``LTM_SESSION_ID``, and the summarized events
    are deleted from the session. The flag is cleared on every path that
    completes; any failure is logged and leaves the flag untouched so the next
    invocation retries.

    Args:
        actor_id: Actor whose flag and LTM store are used.
        session_id: Session to trim.

    Returns:
        None. Never raises; all errors are logged.
    """
    try:
        # The flag read is inside the try so that a DynamoDB error is logged
        # like every other compaction failure instead of reaching the caller.
        if not _get_compaction_flag(actor_id):
            return

        logger.info('[memory_manager] Compaction triggered for actor_id=%s', actor_id)

        # Sort oldest-first explicitly rather than relying on the order the
        # service returns events in. The sort is stable, so an already
        # chronological list is left unchanged.
        events = sorted(
            _get_all_session_events(actor_id, session_id),
            key=_event_epoch_seconds,
        )
        total = len(events)

        if total <= SESSION_WINDOW_SIZE:
            _set_compaction_flag(actor_id, False)
            return

        # Events to compact: everything before the window
        compact_count = total - SESSION_WINDOW_SIZE
        events_to_compact = events[:compact_count]

        # Idempotency: check if we already compacted up to this timestamp.
        # The cursor is the timestamp of the NEWEST event being compacted.
        last_compacted = _get_last_compaction_timestamp(actor_id)
        compacted_through = str(events_to_compact[-1].get('eventTimestamp', ''))
        if last_compacted and last_compacted >= compacted_through:
            logger.info('[memory_manager] Already compacted up to %s, skipping', last_compacted)
            _set_compaction_flag(actor_id, False)
            return

        # Extract text and call Claude.
        # NOTE(review): the transcript helper keeps only the last 200 entries,
        # while every event in events_to_compact is deleted below; a large
        # backlog loses its oldest content without summarizing it.
        text = _extract_text_from_events(events_to_compact)
        if not text.strip():
            logger.info('[memory_manager] No text to compact, clearing flag')
            _set_compaction_flag(actor_id, False)
            return

        extraction = _call_claude_extraction(text)

        # Store LTM extraction as an event
        client = _get_memory_client()
        client.create_event(
            memory_id=MEMORY_ID,
            actor_id=actor_id,
            session_id=LTM_SESSION_ID,
            messages=[(json.dumps(extraction), 'ASSISTANT')],
            event_timestamp=datetime.now(timezone.utc),
            metadata={
                'type': {'stringValue': 'ltm_extraction'},
                'actor_id': {'stringValue': actor_id},
                'compacted_through': {'stringValue': compacted_through},
            },
        )

        # Delete compacted events from the session. A single failed delete must
        # not abort the rest, so each one is attempted independently.
        failed_deletes = 0
        for event in events_to_compact:
            try:
                client.gmdp_client.delete_event(
                    memoryId=MEMORY_ID,
                    actorId=actor_id,
                    sessionId=session_id,
                    eventId=event['eventId'],
                )
            except Exception as e:
                failed_deletes += 1
                logger.warning('[memory_manager] Failed to delete event %s: %s', event.get('eventId'), e)

        if failed_deletes:
            logger.warning(
                '[memory_manager] %d of %d compacted events could not be deleted for actor_id=%s',
                failed_deletes, compact_count, actor_id,
            )

        _set_compaction_flag(actor_id, False)
        logger.info('[memory_manager] Compacted %d events into LTM for actor_id=%s', compact_count, actor_id)

    except Exception as e:
        # Don't clear flag: retry next invocation. logger.exception keeps the
        # traceback, which the plain message alone does not.
        logger.exception('[memory_manager] Compaction failed: %s', e)
|
||||
|
||||
|
||||
def check_window_and_flag(actor_id: str, session_id: str) -> None:
    """Flag the actor for compaction when the session has outgrown the window.

    Counts the session's conversation events and, if there are more than
    ``SESSION_WINDOW_SIZE``, sets the compaction flag. Errors are logged and
    swallowed.
    """
    try:
        event_total = _count_session_events(actor_id, session_id)
        if event_total <= SESSION_WINDOW_SIZE:
            # Still within the window: nothing to flag.
            return
        logger.info(
            '[memory_manager] Session has %d events (> %d), setting compaction flag',
            event_total,
            SESSION_WINDOW_SIZE,
        )
        _set_compaction_flag(actor_id, True)
    except Exception as e:
        logger.error('[memory_manager] Failed to check window: %s', e)
|
||||
|
||||
|
||||
def _ltm_event_epoch(event: dict) -> float:
    """Sort key: the event's ``eventTimestamp`` as epoch seconds (0.0 if unusable)."""
    ts = event.get('eventTimestamp')
    if isinstance(ts, datetime):
        return ts.timestamp()
    try:
        return float(ts)
    except (TypeError, ValueError):
        return 0.0


def _ltm_collect_items(extractions: list[dict], key: str, limit: int, dedupe: bool = True) -> list[str]:
    """Gather up to ``limit`` entries of ``extractions[*][key]`` as display strings."""
    collected: list[str] = []
    seen: set[str] = set()
    for extraction in extractions:
        values = extraction.get(key)
        if not isinstance(values, list):
            # Missing key, null, or a scalar where a list was requested.
            continue
        for value in values:
            # The model may emit objects instead of strings; render them as
            # JSON rather than letting one odd entry abort the whole block.
            text = value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
            normalized = text.lower().strip()
            if not normalized:
                continue
            if dedupe:
                if normalized in seen:
                    continue
                seen.add(normalized)
            collected.append(text)
            if len(collected) >= limit:
                return collected
    return collected


def load_ltm(actor_id: str) -> str:
    """Load all LTM extractions for an actor and format as a system prompt block.

    Up to 50 ``ltm_extraction`` events are read from ``LTM_SESSION_ID``. The
    block contains the newest extraction's summary, then facts (max 30,
    de-duplicated case-insensitively), preferences (max 15, de-duplicated) and
    dates (max 10, not de-duplicated), each taken newest extraction first.
    Payloads that are not JSON objects are ignored.

    Args:
        actor_id: Actor whose extractions are loaded.

    Returns:
        The formatted block, or an empty string when there is nothing to show
        or on failure (non-fatal).
    """
    try:
        client = _get_memory_client()
        events = client.list_events(
            memory_id=MEMORY_ID,
            actor_id=actor_id,
            session_id=LTM_SESSION_ID,
            event_metadata=[{
                'left': {'metadataKey': 'type'},
                'operator': 'EQUALS_TO',
                'right': {'metadataValue': {'stringValue': 'ltm_extraction'}},
            }],
            max_results=50,
            include_payload=True,
        )

        if not events:
            return ''

        # Newest first, by timestamp rather than by the order the service
        # happens to return. Sorting the reversed list with a stable sort keeps
        # the previous result whenever the input is already chronological or
        # timestamps tie.
        newest_first = sorted(reversed(events), key=_ltm_event_epoch, reverse=True)

        extractions: list[dict] = []
        for event in newest_first:
            for item in event.get('payload') or []:
                if 'conversational' not in item:
                    continue
                text = (item['conversational'].get('content') or {}).get('text', '')
                try:
                    parsed = json.loads(text)
                except (json.JSONDecodeError, TypeError):
                    # Not an extraction payload; skip it.
                    continue
                if isinstance(parsed, dict):
                    extractions.append(parsed)

        if not extractions:
            return ''

        # Build the LTM block: most recent summary first, then the item lists
        parts = ['## Long-term memory\n']

        summary = extractions[0].get('summary')
        if summary:
            parts.append(f'**Recent context:** {summary}\n')

        sections = (
            ('**Facts:**', 'facts', 30, True),
            ('**Preferences:**', 'preferences', 15, True),
            # Dates are intentionally not de-duplicated.
            ('**Upcoming dates/events:**', 'dates', 10, False),
        )
        for heading, key, limit, dedupe in sections:
            items = _ltm_collect_items(extractions, key, limit, dedupe)
            if items:
                parts.append(heading)
                parts.extend(f'- {entry}' for entry in items)
                parts.append('')

        result = '\n'.join(parts)
        logger.info('[memory_manager] Loaded LTM block: %d chars from %d extractions', len(result), len(extractions))
        return result

    except Exception as e:
        logger.error('[memory_manager] LTM retrieval failed (non-fatal): %s', e)
        return ''
|
||||
Reference in New Issue
Block a user