feat: add windowed session history + LTM extraction/retrieval
- New memory_manager.py with:
- check_and_compact: runs compaction on flagged sessions (extracts LTM via
Claude Haiku, stores as AgentCore Memory event, deletes old events)
- check_window_and_flag: sets DynamoDB flag when session > 100 events
- load_ltm: retrieves LTM extractions and formats as system prompt block
- Wired into main.py:
- Compaction runs before session_manager creation (trims old events)
- LTM block injected into system prompt
- Window check runs after session close
- SESSION_WINDOW_SIZE = 100 (named constant)
- Compaction is idempotent (uses event timestamps as cursor)
- LTM retrieval failure is non-fatal (logs and continues)
This commit is contained in:
@@ -10,6 +10,7 @@ from bedrock_agentcore.runtime import BedrockAgentCoreApp
|
|||||||
|
|
||||||
from channels.telegram import TelegramAdapter
|
from channels.telegram import TelegramAdapter
|
||||||
from prompt_builder import build_system_prompt, invalidate_prompt
|
from prompt_builder import build_system_prompt, invalidate_prompt
|
||||||
|
import memory_manager
|
||||||
from tools import web as web_tools
|
from tools import web as web_tools
|
||||||
from tools import workspace as ws_tools
|
from tools import workspace as ws_tools
|
||||||
from tools import messaging
|
from tools import messaging
|
||||||
@@ -313,6 +314,9 @@ async def main(payload: dict, context):
|
|||||||
_scheduler_module._current_actor_id = actor_id
|
_scheduler_module._current_actor_id = actor_id
|
||||||
_scheduler_module._current_chat_id = chat_id
|
_scheduler_module._current_chat_id = chat_id
|
||||||
|
|
||||||
|
# Run compaction if flagged from previous invocation (trims old events before load)
|
||||||
|
memory_manager.check_and_compact(actor_id, session_id)
|
||||||
|
|
||||||
memory_config = AgentCoreMemoryConfig(
|
memory_config = AgentCoreMemoryConfig(
|
||||||
memory_id=MEMORY_ID,
|
memory_id=MEMORY_ID,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -352,6 +356,11 @@ async def main(payload: dict, context):
|
|||||||
user_context += f'\nEnrolled services: {", ".join(enrolled)}'
|
user_context += f'\nEnrolled services: {", ".join(enrolled)}'
|
||||||
system_prompt = build_system_prompt(user_context=user_context, actor_id=actor_id)
|
system_prompt = build_system_prompt(user_context=user_context, actor_id=actor_id)
|
||||||
|
|
||||||
|
# Inject long-term memory block before conversation history
|
||||||
|
ltm_block = memory_manager.load_ltm(actor_id)
|
||||||
|
if ltm_block:
|
||||||
|
system_prompt = system_prompt + '\n\n---\n\n' + ltm_block
|
||||||
|
|
||||||
# Inject current datetime so the model always has accurate time context
|
# Inject current datetime so the model always has accurate time context
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
@@ -398,6 +407,8 @@ async def main(payload: dict, context):
|
|||||||
finally:
|
finally:
|
||||||
_typing_active = False
|
_typing_active = False
|
||||||
session_manager.close()
|
session_manager.close()
|
||||||
|
# Check if session exceeds window — flag for compaction on next invocation
|
||||||
|
memory_manager.check_window_and_flag(actor_id, session_id)
|
||||||
|
|
||||||
|
|
||||||
app.run()
|
app.run()
|
||||||
|
|||||||
320
agentclaw/app/agent_claw_main/memory_manager.py
Normal file
320
agentclaw/app/agent_claw_main/memory_manager.py
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
"""Long-term memory manager: windowed loading, compaction, and LTM retrieval."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from bedrock_agentcore.memory.client import MemoryClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MEMORY_ID = 'agentclaw_AgentClawMemory-i7Csf776AH'
|
||||||
|
SESSION_WINDOW_SIZE = 100
|
||||||
|
USERS_TABLE_NAME = os.environ.get('USERS_TABLE_NAME', 'agent-claw-users')
|
||||||
|
LTM_SESSION_ID = 'ltm-extractions'
|
||||||
|
HAIKU_MODEL_ID = 'us.anthropic.claude-3-5-haiku-20241022-v1:0'
|
||||||
|
|
||||||
|
_memory_client: MemoryClient | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_memory_client() -> MemoryClient:
|
||||||
|
global _memory_client
|
||||||
|
if _memory_client is None:
|
||||||
|
_memory_client = MemoryClient(region_name='us-east-1')
|
||||||
|
return _memory_client
|
||||||
|
|
||||||
|
|
||||||
|
def _get_compaction_flag(actor_id: str) -> bool:
|
||||||
|
"""Check if compaction is needed for this actor."""
|
||||||
|
ddb = boto3.resource('dynamodb', region_name='us-east-1')
|
||||||
|
table = ddb.Table(USERS_TABLE_NAME)
|
||||||
|
resp = table.get_item(Key={'actor_id': actor_id})
|
||||||
|
return resp.get('Item', {}).get('needs_compaction', False)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_compaction_flag(actor_id: str, value: bool) -> None:
|
||||||
|
"""Set or clear the compaction flag."""
|
||||||
|
ddb = boto3.resource('dynamodb', region_name='us-east-1')
|
||||||
|
table = ddb.Table(USERS_TABLE_NAME)
|
||||||
|
table.update_item(
|
||||||
|
Key={'actor_id': actor_id},
|
||||||
|
UpdateExpression='SET needs_compaction = :v',
|
||||||
|
ExpressionAttributeValues={':v': value},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_session_events(actor_id: str, session_id: str) -> int:
|
||||||
|
"""Count total events in the session (excluding state/agent metadata events)."""
|
||||||
|
client = _get_memory_client()
|
||||||
|
events = client.list_events(
|
||||||
|
memory_id=MEMORY_ID,
|
||||||
|
actor_id=actor_id,
|
||||||
|
session_id=session_id,
|
||||||
|
max_results=10000,
|
||||||
|
include_payload=False,
|
||||||
|
)
|
||||||
|
# Filter out session/agent state events (they have stateType metadata)
|
||||||
|
return sum(1 for e in events if not e.get('metadata', {}).get('stateType'))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_all_session_events(actor_id: str, session_id: str) -> list[dict]:
|
||||||
|
"""Get all conversation events (excluding state metadata events)."""
|
||||||
|
client = _get_memory_client()
|
||||||
|
events = client.list_events(
|
||||||
|
memory_id=MEMORY_ID,
|
||||||
|
actor_id=actor_id,
|
||||||
|
session_id=session_id,
|
||||||
|
max_results=10000,
|
||||||
|
include_payload=True,
|
||||||
|
)
|
||||||
|
return [e for e in events if not e.get('metadata', {}).get('stateType')]
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_from_events(events: list[dict]) -> str:
|
||||||
|
"""Extract conversation text from events for summarization."""
|
||||||
|
lines = []
|
||||||
|
for event in events:
|
||||||
|
for item in event.get('payload', []):
|
||||||
|
if 'conversational' in item:
|
||||||
|
conv = item['conversational']
|
||||||
|
role = conv.get('role', 'UNKNOWN')
|
||||||
|
text = conv.get('content', {}).get('text', '')
|
||||||
|
lines.append(f'{role}: {text}')
|
||||||
|
elif 'blob' in item:
|
||||||
|
try:
|
||||||
|
blob = json.loads(item['blob']) if isinstance(item['blob'], str) else item['blob']
|
||||||
|
if isinstance(blob, list) and blob:
|
||||||
|
for msg in blob:
|
||||||
|
if isinstance(msg, (list, tuple)) and len(msg) == 2:
|
||||||
|
lines.append(f'{msg[1]}: {msg[0]}')
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
return '\n'.join(lines[-200:]) # Cap at last 200 lines to stay within context
|
||||||
|
|
||||||
|
|
||||||
|
def _call_claude_extraction(conversation_text: str) -> dict:
|
||||||
|
"""Call Claude Haiku to extract structured LTM from conversation text."""
|
||||||
|
bedrock = boto3.client('bedrock-runtime', region_name='us-east-1')
|
||||||
|
prompt = (
|
||||||
|
'Extract structured long-term memory from this conversation. '
|
||||||
|
'Return ONLY valid JSON with these keys:\n'
|
||||||
|
'- "summary": 3-5 sentence narrative of what was discussed\n'
|
||||||
|
'- "facts": array of factual statements worth remembering\n'
|
||||||
|
'- "preferences": array of user preferences expressed\n'
|
||||||
|
'- "dates": array of events/deadlines with date/time mentioned\n'
|
||||||
|
'- "topics": array of topic keywords\n\n'
|
||||||
|
'Conversation:\n' + conversation_text
|
||||||
|
)
|
||||||
|
resp = bedrock.converse(
|
||||||
|
modelId=HAIKU_MODEL_ID,
|
||||||
|
messages=[{'role': 'user', 'content': [{'text': prompt}]}],
|
||||||
|
inferenceConfig={'maxTokens': 1024},
|
||||||
|
)
|
||||||
|
text = resp['output']['message']['content'][0]['text']
|
||||||
|
# Parse JSON from response (handle markdown code blocks)
|
||||||
|
if '```' in text:
|
||||||
|
text = text.split('```')[1]
|
||||||
|
if text.startswith('json'):
|
||||||
|
text = text[4:]
|
||||||
|
return json.loads(text.strip())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_last_compaction_timestamp(actor_id: str) -> str | None:
|
||||||
|
"""Get the timestamp of the most recent LTM extraction to avoid duplicates."""
|
||||||
|
client = _get_memory_client()
|
||||||
|
events = client.list_events(
|
||||||
|
memory_id=MEMORY_ID,
|
||||||
|
actor_id=actor_id,
|
||||||
|
session_id=LTM_SESSION_ID,
|
||||||
|
event_metadata=[{
|
||||||
|
'left': {'metadataKey': 'type'},
|
||||||
|
'operator': 'EQUALS_TO',
|
||||||
|
'right': {'metadataValue': {'stringValue': 'ltm_extraction'}},
|
||||||
|
}],
|
||||||
|
max_results=1,
|
||||||
|
include_payload=False,
|
||||||
|
)
|
||||||
|
if events:
|
||||||
|
# Events are returned chronologically; last one is most recent
|
||||||
|
return str(events[-1].get('eventTimestamp', ''))
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def check_and_compact(actor_id: str, session_id: str) -> None:
|
||||||
|
"""Run compaction if the flag is set. Call BEFORE creating session_manager."""
|
||||||
|
if not _get_compaction_flag(actor_id):
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info('[memory_manager] Compaction triggered for actor_id=%s', actor_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
events = _get_all_session_events(actor_id, session_id)
|
||||||
|
total = len(events)
|
||||||
|
|
||||||
|
if total <= SESSION_WINDOW_SIZE:
|
||||||
|
_set_compaction_flag(actor_id, False)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Events to compact: everything before the window
|
||||||
|
compact_count = total - SESSION_WINDOW_SIZE
|
||||||
|
events_to_compact = events[:compact_count]
|
||||||
|
|
||||||
|
# Idempotency: check if we already compacted up to this timestamp
|
||||||
|
last_compacted = _get_last_compaction_timestamp(actor_id)
|
||||||
|
oldest_event_ts = str(events_to_compact[-1].get('eventTimestamp', ''))
|
||||||
|
if last_compacted and last_compacted >= oldest_event_ts:
|
||||||
|
logger.info('[memory_manager] Already compacted up to %s, skipping', last_compacted)
|
||||||
|
_set_compaction_flag(actor_id, False)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Extract text and call Claude
|
||||||
|
text = _extract_text_from_events(events_to_compact)
|
||||||
|
if not text.strip():
|
||||||
|
logger.info('[memory_manager] No text to compact, clearing flag')
|
||||||
|
_set_compaction_flag(actor_id, False)
|
||||||
|
return
|
||||||
|
|
||||||
|
extraction = _call_claude_extraction(text)
|
||||||
|
|
||||||
|
# Store LTM extraction as an event
|
||||||
|
client = _get_memory_client()
|
||||||
|
client.create_event(
|
||||||
|
memory_id=MEMORY_ID,
|
||||||
|
actor_id=actor_id,
|
||||||
|
session_id=LTM_SESSION_ID,
|
||||||
|
messages=[(json.dumps(extraction), 'ASSISTANT')],
|
||||||
|
event_timestamp=datetime.now(timezone.utc),
|
||||||
|
metadata={
|
||||||
|
'type': {'stringValue': 'ltm_extraction'},
|
||||||
|
'actor_id': {'stringValue': actor_id},
|
||||||
|
'compacted_through': {'stringValue': oldest_event_ts},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete compacted events from the session
|
||||||
|
for event in events_to_compact:
|
||||||
|
try:
|
||||||
|
client.gmdp_client.delete_event(
|
||||||
|
memoryId=MEMORY_ID,
|
||||||
|
actorId=actor_id,
|
||||||
|
sessionId=session_id,
|
||||||
|
eventId=event['eventId'],
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning('[memory_manager] Failed to delete event %s: %s', event.get('eventId'), e)
|
||||||
|
|
||||||
|
_set_compaction_flag(actor_id, False)
|
||||||
|
logger.info('[memory_manager] Compacted %d events into LTM for actor_id=%s', compact_count, actor_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error('[memory_manager] Compaction failed: %s', e)
|
||||||
|
# Don't clear flag — retry next invocation
|
||||||
|
|
||||||
|
|
||||||
|
def check_window_and_flag(actor_id: str, session_id: str) -> None:
|
||||||
|
"""After session loads, check if we exceed the window and set flag for next time."""
|
||||||
|
try:
|
||||||
|
count = _count_session_events(actor_id, session_id)
|
||||||
|
if count > SESSION_WINDOW_SIZE:
|
||||||
|
logger.info('[memory_manager] Session has %d events (> %d), setting compaction flag',
|
||||||
|
count, SESSION_WINDOW_SIZE)
|
||||||
|
_set_compaction_flag(actor_id, True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error('[memory_manager] Failed to check window: %s', e)
|
||||||
|
|
||||||
|
|
||||||
|
def load_ltm(actor_id: str) -> str:
|
||||||
|
"""Load all LTM extractions for an actor and format as a system prompt block.
|
||||||
|
|
||||||
|
Returns empty string on failure (non-fatal).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = _get_memory_client()
|
||||||
|
events = client.list_events(
|
||||||
|
memory_id=MEMORY_ID,
|
||||||
|
actor_id=actor_id,
|
||||||
|
session_id=LTM_SESSION_ID,
|
||||||
|
event_metadata=[{
|
||||||
|
'left': {'metadataKey': 'type'},
|
||||||
|
'operator': 'EQUALS_TO',
|
||||||
|
'right': {'metadataValue': {'stringValue': 'ltm_extraction'}},
|
||||||
|
}],
|
||||||
|
max_results=50,
|
||||||
|
include_payload=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not events:
|
||||||
|
return ''
|
||||||
|
|
||||||
|
# Parse extractions (events are chronological, reverse for most-recent-first)
|
||||||
|
extractions = []
|
||||||
|
for event in reversed(events):
|
||||||
|
for item in event.get('payload', []):
|
||||||
|
if 'conversational' in item:
|
||||||
|
text = item['conversational'].get('content', {}).get('text', '')
|
||||||
|
try:
|
||||||
|
extractions.append(json.loads(text))
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not extractions:
|
||||||
|
return ''
|
||||||
|
|
||||||
|
# Build the LTM block: most recent summary first, then deduplicated lists
|
||||||
|
parts = ['## Long-term memory\n']
|
||||||
|
|
||||||
|
# Most recent summary
|
||||||
|
if extractions[0].get('summary'):
|
||||||
|
parts.append(f'**Recent context:** {extractions[0]["summary"]}\n')
|
||||||
|
|
||||||
|
# Deduplicated facts
|
||||||
|
all_facts = []
|
||||||
|
seen_facts: set[str] = set()
|
||||||
|
for ext in extractions:
|
||||||
|
for f in ext.get('facts', []):
|
||||||
|
key = f.lower().strip()
|
||||||
|
if key not in seen_facts:
|
||||||
|
seen_facts.add(key)
|
||||||
|
all_facts.append(f)
|
||||||
|
if all_facts:
|
||||||
|
parts.append('**Facts:**')
|
||||||
|
for f in all_facts[:30]: # Cap at 30
|
||||||
|
parts.append(f'- {f}')
|
||||||
|
parts.append('')
|
||||||
|
|
||||||
|
# Deduplicated preferences
|
||||||
|
all_prefs = []
|
||||||
|
seen_prefs: set[str] = set()
|
||||||
|
for ext in extractions:
|
||||||
|
for p in ext.get('preferences', []):
|
||||||
|
key = p.lower().strip()
|
||||||
|
if key not in seen_prefs:
|
||||||
|
seen_prefs.add(key)
|
||||||
|
all_prefs.append(p)
|
||||||
|
if all_prefs:
|
||||||
|
parts.append('**Preferences:**')
|
||||||
|
for p in all_prefs[:15]:
|
||||||
|
parts.append(f'- {p}')
|
||||||
|
parts.append('')
|
||||||
|
|
||||||
|
# Dates (most recent extractions first, keep all)
|
||||||
|
all_dates = []
|
||||||
|
for ext in extractions:
|
||||||
|
all_dates.extend(ext.get('dates', []))
|
||||||
|
if all_dates:
|
||||||
|
parts.append('**Upcoming dates/events:**')
|
||||||
|
for d in all_dates[:10]:
|
||||||
|
parts.append(f'- {d}')
|
||||||
|
parts.append('')
|
||||||
|
|
||||||
|
result = '\n'.join(parts)
|
||||||
|
logger.info('[memory_manager] Loaded LTM block: %d chars from %d extractions', len(result), len(extractions))
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error('[memory_manager] LTM retrieval failed (non-fatal): %s', e)
|
||||||
|
return ''
|
||||||
Reference in New Issue
Block a user