Source code for agentscope_runtime.engine.deployers.utils.service_utils.interrupt.local_backend

# -*- coding: utf-8 -*-
import asyncio
import time
from typing import AsyncGenerator, Dict, Optional, Set, Tuple

from .base_backend import BaseInterruptBackend, TaskState


class LocalInterruptBackend(BaseInterruptBackend):
    """
    An in-memory implementation of BaseInterruptBackend using asyncio
    primitives.

    Suitable for single-process environments where Redis is not available.

    Task states are stored in a dict as ``(state, deadline)`` pairs, where
    ``deadline`` is a ``time.monotonic()`` reading. Expired records are
    evicted lazily, i.e. only when they are looked up. Pub/sub is implemented
    with one unbounded ``asyncio.Queue`` per subscriber.

    Note:
        The internal ``asyncio.Lock`` serialises coroutines running on the
        same event loop. It is not a thread lock, so an instance must not be
        shared between threads.
    """

    def __init__(self) -> None:
        """Create an empty backend with no stored states or subscribers."""
        # full key -> (state, monotonic deadline after which it is expired)
        self._states: Dict[str, Tuple[TaskState, float]] = {}
        # channel name -> queues of the currently active subscribers
        self._subscribers: Dict[str, Set[asyncio.Queue]] = {}
        # Lock to ensure atomicity of compound operations (Read-Modify-Write)
        self._lock = asyncio.Lock()

    def _get_full_key(self, key: str) -> str:
        """Internal helper to maintain consistent key prefixing."""
        return f"state:{key}"

    def _get_state_logic(self, key: str) -> Optional[TaskState]:
        """
        Internal non-async logic to retrieve state with lazy expiration.

        Returns:
            The stored state, or ``None`` when the key is unknown or its
            record has expired (an expired record is deleted on the spot).
        """
        full_key = self._get_full_key(key)
        record = self._states.get(full_key)
        if record is None:
            return None

        state, expire_at = record
        # Monotonic clock: immune to wall-clock adjustments (NTP, manual
        # changes) that would otherwise expire records early or late.
        if time.monotonic() > expire_at:
            del self._states[full_key]
            return None
        return state

    def _set_state_logic(self, key: str, state: TaskState, ttl: int) -> None:
        """
        Internal non-async logic to store state with an expiration deadline.

        Args:
            key: Unprefixed task key.
            state: State to store.
            ttl: Lifetime of the record in seconds, counted from now.
        """
        full_key = self._get_full_key(key)
        # Must use the same clock as the expiry check in _get_state_logic.
        self._states[full_key] = (state, time.monotonic() + ttl)

    async def get_task_state(self, key: str) -> Optional[TaskState]:
        """
        Retrieve the task state for ``key`` while holding the backend lock.

        Returns:
            The stored state, or ``None`` if it is missing or expired.
        """
        async with self._lock:
            return self._get_state_logic(key)

    async def set_task_state(
        self,
        key: str,
        state: TaskState,
        ttl: int = 3600,
    ) -> None:
        """
        Store the task state for ``key`` while holding the backend lock.

        Args:
            key: Unprefixed task key.
            state: State to store; overwrites any existing record.
            ttl: Lifetime of the record in seconds.
        """
        async with self._lock:
            self._set_state_logic(key, state, ttl)

    async def compare_and_set_state(
        self,
        key: str,
        new_state: TaskState,
        expected_state: TaskState,
        negate: bool = False,
        ttl: int = 3600,
    ) -> bool:
        """
        Atomic Compare-And-Set (CAS) operation.

        Updates state only if the current state satisfies the expected
        condition.

        Args:
            key: Unprefixed task key.
            new_state: State to write when the condition holds.
            expected_state: State the current value is compared against.
            negate: When ``False`` the write happens if the current state
                equals ``expected_state``; when ``True`` it happens if the
                current state differs from it. A missing or expired record
                reads as ``None`` for this comparison.
            ttl: Lifetime in seconds given to the record when it is written.

        Returns:
            ``True`` if the state was written, ``False`` otherwise.
        """
        async with self._lock:
            current = self._get_state_logic(key)
            match = current == expected_state
            condition_met = not match if negate else match

            if condition_met:
                self._set_state_logic(key, new_state, ttl)
                return True
            return False

    async def delete_task_state(self, key: str) -> None:
        """Manually remove a task state record; a missing key is ignored."""
        async with self._lock:
            full_key = self._get_full_key(key)
            self._states.pop(full_key, None)

    async def publish_event(self, channel: str, message: str) -> None:
        """
        Broadcast a message to all queues subscribed to the channel.

        A channel without subscribers is silently ignored; the message is
        not retained for later subscribers.
        """
        queues = self._subscribers.get(channel)
        if not queues:
            return
        # Iterate over a copy so subscribers may come and go safely. The
        # queues are unbounded, so put_nowait() cannot raise QueueFull.
        for queue in list(queues):
            queue.put_nowait(message)

    async def subscribe_listen(
        self,
        channel: str,
    ) -> AsyncGenerator[str, None]:
        """
        Subscribe to a channel and yield incoming messages.

        The subscription is registered when the generator is first iterated
        (an async generator body does not run before that) and removed when
        the generator finishes or is closed.

        Yields:
            Messages published to ``channel`` after registration, in the
            order they were queued. Iteration ends once the shutdown
            sentinel sent by ``aclose`` is received.
        """
        queue: asyncio.Queue[Optional[str]] = asyncio.Queue()
        self._subscribers.setdefault(channel, set()).add(queue)

        try:
            while True:
                message = await queue.get()
                # If we receive None, it means the backend is closing
                if message is None:
                    break
                yield message
        finally:
            # Drop this subscriber; remove the channel entry once it has no
            # listeners left. The channel may already be gone after aclose().
            listeners = self._subscribers.get(channel)
            if listeners is not None:
                listeners.discard(queue)
                if not listeners:
                    del self._subscribers[channel]

    async def aclose(self) -> None:
        """
        Release all resources and clear internal storage.

        Every active subscriber receives a ``None`` sentinel so that its
        ``subscribe_listen`` loop terminates, then all stored states and
        subscriptions are dropped.
        """
        async with self._lock:
            # 1. Notify all subscribers to stop. Snapshot the queues first so
            #    the loop never walks a live dict/set view.
            pending = [
                queue
                for channel_queues in self._subscribers.values()
                for queue in channel_queues
            ]
            for queue in pending:
                queue.put_nowait(None)  # Send the shutdown signal

            # 2. Clear the storage
            self._states.clear()
            self._subscribers.clear()