Bases: TaskStore
A simple in-memory implementation of TaskStore.
Features:
- Automatic TTL-based cleanup (lazy expiration)
- Safe for concurrent use by coroutines on a single event loop (no locking; do not share an instance across OS threads)
- Pagination support for list_tasks
Limitations:
- All data lost on restart
- Not suitable for distributed systems
- No persistence
For production, implement TaskStore with Redis, PostgreSQL, etc.
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
class InMemoryTaskStore(TaskStore):
    """A simple in-memory implementation of TaskStore.

    Features:
    - Automatic TTL-based cleanup (lazy expiration)
    - Safe for concurrent use by coroutines on a single event loop
      (no locking; do not share an instance across OS threads)
    - Pagination support for list_tasks

    Limitations:
    - All data lost on restart
    - Not suitable for distributed systems
    - No persistence

    For production, implement TaskStore with Redis, PostgreSQL, etc.
    """

    def __init__(self, page_size: int = 10) -> None:
        """Create an empty store.

        Args:
            page_size: Maximum number of tasks returned per ``list_tasks`` page.

        Raises:
            ValueError: If ``page_size`` is less than 1.
        """
        if page_size < 1:
            # A non-positive page size would make list_tasks return empty or
            # nonsensical slices instead of paging.
            raise ValueError(f"page_size must be at least 1, got {page_size}")
        self._tasks: dict[str, StoredTask] = {}
        self._page_size = page_size
        # At most one pending (unset) event per task; all waiters on that task share it.
        self._update_events: dict[str, anyio.Event] = {}

    @staticmethod
    def _copy_task(task: Task) -> Task:
        """Return a detached copy of ``task`` so callers cannot mutate stored state."""
        return Task(**task.model_dump())

    def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
        """Return the absolute UTC expiry for a TTL in milliseconds, or None when there is no TTL."""
        if ttl_ms is None:
            return None
        return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms)

    def _is_expired(self, stored: StoredTask, now: datetime | None = None) -> bool:
        """Return True if ``stored`` has an expiry that is at or before ``now`` (default: current UTC time)."""
        if stored.expires_at is None:
            return False
        if now is None:
            now = datetime.now(timezone.utc)
        return now >= stored.expires_at

    def _wake_waiters(self, task_id: str) -> None:
        """Release every coroutine blocked in wait_for_update for ``task_id``.

        The event is removed before being set, so the next waiter creates a
        fresh one (anyio.Event cannot be cleared) and no set event lingers.
        """
        event = self._update_events.pop(task_id, None)
        if event is not None:
            event.set()

    def _cleanup_expired(self) -> None:
        """Remove all expired tasks. Called lazily during access operations."""
        now = datetime.now(timezone.utc)  # one timestamp for the whole sweep
        expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored, now)]
        for task_id in expired_ids:
            del self._tasks[task_id]
            # Waiters on a task that no longer exists would otherwise block forever,
            # and their event would stay in _update_events indefinitely.
            self._wake_waiters(task_id)

    async def create_task(
        self,
        metadata: TaskMetadata,
        task_id: str | None = None,
    ) -> Task:
        """Create a new task with the given metadata and return a copy of it.

        Raises:
            ValueError: If a task with the resulting ID is already stored.
        """
        # Cleanup expired tasks on access
        self._cleanup_expired()
        task = create_task_state(metadata, task_id)
        if task.task_id in self._tasks:
            raise ValueError(f"Task with ID {task.task_id} already exists")
        stored = StoredTask(
            task=task,
            expires_at=self._calculate_expiry(metadata.ttl),
        )
        self._tasks[task.task_id] = stored
        return self._copy_task(task)

    async def get_task(self, task_id: str) -> Task | None:
        """Return a copy of the task, or None if it does not exist or has expired."""
        # Cleanup expired tasks on access
        self._cleanup_expired()
        stored = self._tasks.get(task_id)
        if stored is None:
            return None
        return self._copy_task(stored.task)

    async def update_task(
        self,
        task_id: str,
        status: TaskStatus | None = None,
        status_message: str | None = None,
    ) -> Task:
        """Update a task's status and/or status message and return a copy of it.

        Waiters blocked in wait_for_update are woken only when the status changes.

        Raises:
            ValueError: If the task does not exist, or if ``status`` would move
                the task out of a terminal status.
        """
        stored = self._tasks.get(task_id)
        if stored is None:
            raise ValueError(f"Task with ID {task_id} not found")

        # Per spec: Terminal states MUST NOT transition to any other status
        if status is not None and status != stored.task.status and is_terminal(stored.task.status):
            raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'")

        status_changed = False
        if status is not None and stored.task.status != status:
            stored.task.status = status
            status_changed = True

        if status_message is not None:
            stored.task.status_message = status_message

        # Refreshed on every call, even when neither field actually changed.
        stored.task.last_updated_at = datetime.now(timezone.utc)

        if status_changed:
            # Restart the TTL only when the task *enters* a terminal status, so
            # repeated updates to an already-terminal task cannot keep extending it.
            if is_terminal(stored.task.status) and stored.task.ttl is not None:
                stored.expires_at = self._calculate_expiry(stored.task.ttl)
            await self.notify_update(task_id)

        return self._copy_task(stored.task)

    async def store_result(self, task_id: str, result: Result) -> None:
        """Store ``result`` on the task, replacing any earlier result.

        Raises:
            ValueError: If the task does not exist.
        """
        stored = self._tasks.get(task_id)
        if stored is None:
            raise ValueError(f"Task with ID {task_id} not found")
        stored.result = result

    async def get_result(self, task_id: str) -> Result | None:
        """Return ``stored.result`` for the task, or None if the task does not exist or has expired."""
        # Purge expired tasks first so this agrees with get_task, which never
        # exposes a task whose TTL has elapsed.
        self._cleanup_expired()
        stored = self._tasks.get(task_id)
        if stored is None:
            return None
        return stored.result

    async def list_tasks(
        self,
        cursor: str | None = None,
    ) -> tuple[list[Task], str | None]:
        """List unexpired tasks in insertion order, one page at a time.

        Args:
            cursor: ID of the last task on the previous page (the previous
                call's ``next_cursor``); None starts from the beginning.

        Returns:
            ``(tasks, next_cursor)``; ``next_cursor`` is None on the last page.

        Raises:
            ValueError: If ``cursor`` is not the ID of a currently stored task.
                NOTE(review): this includes a cursor whose task was deleted or
                expired since the previous page was served.
        """
        # Cleanup expired tasks on access
        self._cleanup_expired()

        all_task_ids = list(self._tasks)
        start_index = 0
        if cursor is not None:
            try:
                start_index = all_task_ids.index(cursor) + 1
            except ValueError:
                # The list.index error carries no useful context for callers.
                raise ValueError(f"Invalid cursor: {cursor}") from None

        page_task_ids = all_task_ids[start_index : start_index + self._page_size]
        tasks = [self._copy_task(self._tasks[tid].task) for tid in page_task_ids]

        next_cursor = None
        if start_index + self._page_size < len(all_task_ids) and page_task_ids:
            next_cursor = page_task_ids[-1]
        return tasks, next_cursor

    async def delete_task(self, task_id: str) -> bool:
        """Delete a task; return True if it existed, False otherwise.

        Coroutines blocked in wait_for_update on the task are woken so they do
        not wait forever on a task that no longer exists.
        """
        if task_id not in self._tasks:
            return False
        del self._tasks[task_id]
        self._wake_waiters(task_id)
        return True

    async def wait_for_update(self, task_id: str) -> None:
        """Block until the task's status changes (or the task is removed).

        Concurrent waiters on the same task share one pending event, so a later
        waiter never orphans an earlier one.

        Raises:
            ValueError: If the task does not exist.
        """
        if task_id not in self._tasks:
            raise ValueError(f"Task with ID {task_id} not found")
        event = self._update_events.get(task_id)
        if event is None or event.is_set():
            # anyio.Event can't be cleared, so a fresh one is needed per round.
            event = anyio.Event()
            self._update_events[task_id] = event
        await event.wait()

    async def notify_update(self, task_id: str) -> None:
        """Wake all coroutines waiting on ``task_id``; a no-op if none are waiting."""
        self._wake_waiters(task_id)

    # --- Testing/debugging helpers ---

    def cleanup(self) -> None:
        """Remove all tasks and release any blocked waiters (useful for testing or graceful shutdown)."""
        self._tasks.clear()
        for event in self._update_events.values():
            event.set()
        self._update_events.clear()

    def get_all_tasks(self) -> list[Task]:
        """Get all tasks (useful for debugging). Returns copies to prevent modification."""
        self._cleanup_expired()
        return [self._copy_task(stored.task) for stored in self._tasks.values()]
|
create_task
async
create_task(
metadata: TaskMetadata,
task_id: str | None = None,
) -> Task
Create a new task with the given metadata.
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
async def create_task(
    self,
    metadata: TaskMetadata,
    task_id: str | None = None,
) -> Task:
    """Build a task from ``metadata``, store it, and hand back a detached copy.

    Expired tasks are purged before anything else.

    Raises:
        ValueError: If a task with the resulting ID is already stored.
    """
    self._cleanup_expired()  # lazy expiry happens on access
    new_task = create_task_state(metadata, task_id)
    new_id = new_task.task_id
    if new_id in self._tasks:
        raise ValueError(f"Task with ID {new_id} already exists")
    expiry = self._calculate_expiry(metadata.ttl)
    self._tasks[new_id] = StoredTask(task=new_task, expires_at=expiry)
    # Re-building from a dump means callers can't mutate what the store holds.
    return Task(**new_task.model_dump())
|
get_task
async
get_task(task_id: str) -> Task | None
Get a task by ID.
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
92
93
94
95
96
97
98
99
100
101
async def get_task(self, task_id: str) -> Task | None:
    """Look up ``task_id`` after purging expired tasks.

    Returns a detached copy of the task, or None when no such task is stored.
    """
    self._cleanup_expired()
    found = self._tasks.get(task_id)
    return None if found is None else Task(**found.task.model_dump())
|
update_task
async
update_task(
task_id: str,
status: TaskStatus | None = None,
status_message: str | None = None,
) -> Task
Update a task's status and/or message.
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
async def update_task(
    self,
    task_id: str,
    status: TaskStatus | None = None,
    status_message: str | None = None,
) -> Task:
    """Update a task's status and/or status message and return a copy of it.

    Waiters blocked in wait_for_update are notified only when the status changes.

    Raises:
        ValueError: If the task does not exist, or if ``status`` would move
            the task out of a terminal status.
    """
    stored = self._tasks.get(task_id)
    if stored is None:
        raise ValueError(f"Task with ID {task_id} not found")

    # Per spec: Terminal states MUST NOT transition to any other status
    if status is not None and status != stored.task.status and is_terminal(stored.task.status):
        raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'")

    status_changed = False
    if status is not None and stored.task.status != status:
        stored.task.status = status
        status_changed = True

    if status_message is not None:
        stored.task.status_message = status_message

    # Refreshed on every call, even when neither field actually changed.
    stored.task.last_updated_at = datetime.now(timezone.utc)

    if status_changed:
        # Restart the TTL only when the task *enters* a terminal status, so
        # repeated updates to an already-terminal task cannot keep extending it.
        if is_terminal(stored.task.status) and stored.task.ttl is not None:
            stored.expires_at = self._calculate_expiry(stored.task.ttl)
        await self.notify_update(task_id)

    return Task(**stored.task.model_dump())
|
store_result
async
store_result(task_id: str, result: Result) -> None
Store the result for a task.
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
140
141
142
143
144
145
async def store_result(self, task_id: str, result: Result) -> None:
    """Attach ``result`` to the stored task, overwriting any earlier result.

    Raises:
        ValueError: If no task with ``task_id`` is stored.
    """
    if (entry := self._tasks.get(task_id)) is None:
        raise ValueError(f"Task with ID {task_id} not found")
    entry.result = result
|
get_result
async
get_result(task_id: str) -> Result | None
Get the stored result for a task.
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
148
149
150
151
152
153
async def get_result(self, task_id: str) -> Result | None:
    """Return ``stored.result`` for the task, or None if the task does not exist or has expired."""
    # Purge expired tasks first so this agrees with get_task, which never
    # exposes a task whose TTL has elapsed.
    self._cleanup_expired()
    stored = self._tasks.get(task_id)
    if stored is None:
        return None
    return stored.result
|
list_tasks
async
list_tasks(
cursor: str | None = None,
) -> tuple[list[Task], str | None]
List tasks with pagination.
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
async def list_tasks(
    self,
    cursor: str | None = None,
) -> tuple[list[Task], str | None]:
    """List unexpired tasks in insertion order, one page at a time.

    Args:
        cursor: ID of the last task on the previous page (the previous call's
            ``next_cursor``); None starts from the beginning.

    Returns:
        ``(tasks, next_cursor)``; ``next_cursor`` is None on the last page.

    Raises:
        ValueError: If ``cursor`` is not the ID of a currently stored task.
    """
    # Cleanup expired tasks on access
    self._cleanup_expired()

    all_task_ids = list(self._tasks)
    start_index = 0
    if cursor is not None:
        try:
            start_index = all_task_ids.index(cursor) + 1
        except ValueError:
            # The list.index error carries no useful context for callers.
            raise ValueError(f"Invalid cursor: {cursor}") from None

    page_task_ids = all_task_ids[start_index : start_index + self._page_size]
    tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids]

    next_cursor = None
    if start_index + self._page_size < len(all_task_ids) and page_task_ids:
        next_cursor = page_task_ids[-1]
    return tasks, next_cursor
|
delete_task
async
delete_task(task_id: str) -> bool
Delete a task.
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
184
185
186
187
188
189
async def delete_task(self, task_id: str) -> bool:
    """Delete a task; return True if it existed, False otherwise.

    Any pending update event for the task is set and removed, so coroutines
    blocked in wait_for_update are released and the event does not leak.
    """
    if task_id not in self._tasks:
        return False
    del self._tasks[task_id]
    event = self._update_events.pop(task_id, None)
    if event is not None:
        event.set()
    return True
|
wait_for_update
async
wait_for_update(task_id: str) -> None
Wait until the task status changes.
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
192
193
194
195
196
197
198
199
async def wait_for_update(self, task_id: str) -> None:
    """Block until the task's pending update event is set.

    Concurrent waiters on the same task share one unset event; replacing it
    would orphan earlier waiters, whose event would then never be set.

    Raises:
        ValueError: If the task does not exist.
    """
    if task_id not in self._tasks:
        raise ValueError(f"Task with ID {task_id} not found")
    event = self._update_events.get(task_id)
    if event is None or event.is_set():
        # anyio.Event can't be cleared, so a fresh one is needed per round.
        event = anyio.Event()
        self._update_events[task_id] = event
    await event.wait()
|
notify_update
async
notify_update(task_id: str) -> None
Signal that a task has been updated.
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
async def notify_update(self, task_id: str) -> None:
    """Wake all coroutines waiting on ``task_id``; a no-op if none are waiting.

    The event is removed as it fires. A set event is never reused, because
    anyio.Event can't be cleared, so keeping it would only leak.
    """
    event = self._update_events.pop(task_id, None)
    if event is not None:
        event.set()
|
cleanup
cleanup() -> None
Cleanup all tasks (useful for testing or graceful shutdown).
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
def cleanup(self) -> None:
    """Remove all tasks and release blocked waiters (useful for testing or graceful shutdown).

    Every pending update event is set before being discarded. Otherwise
    coroutines waiting on them would block forever once the events are
    dropped from the store.
    """
    self._tasks.clear()
    for event in self._update_events.values():
        event.set()
    self._update_events.clear()
|
get_all_tasks
get_all_tasks() -> list[Task]
Get all tasks (useful for debugging). Returns copies to prevent modification.
Source code in src/mcp/shared/experimental/tasks/in_memory_task_store.py
def get_all_tasks(self) -> list[Task]:
    """Snapshot every task still stored after the lazy expiry sweep (debugging aid).

    Each element is a fresh Task rebuilt from a dump, so editing the returned
    objects leaves the store untouched.
    """
    self._cleanup_expired()
    live_entries = self._tasks.values()
    return [Task(**entry.task.model_dump()) for entry in live_entries]
|