Source code for agentscope_runtime.engine.deployers.utils.service_utils.interrupt.redis_backend
# -*- coding: utf-8 -*-
from typing import AsyncGenerator, Optional
import redis.asyncio as redis
from .base_backend import TaskState, BaseInterruptBackend
[docs]
class RedisInterruptBackend(BaseInterruptBackend):
    """Interrupt backend backed by Redis.

    Task states are stored as string values under ``state:<key>`` with an
    expiry, and events are exchanged over Redis Pub/Sub channels.
    """

    # Atomic Get-Compare-Set executed server-side, so no other client can
    # modify the key between the read and the write.
    #   KEYS[1] - full state key
    #   ARGV[1] - new state value
    #   ARGV[2] - expected state value
    #   ARGV[3] - '1' to invert the comparison, '0' otherwise
    #   ARGV[4] - expiry (seconds) applied when the key is written
    # GET on a missing key returns false, which never equals ARGV[2].
    _CAS_LUA_SCRIPT = """
    local current = redis.call('get', KEYS[1])
    local expected = ARGV[2]
    local is_negate = (ARGV[3] == '1')
    -- Check if current state matches the expected state
    local match = (current == expected)
    -- Determine if the condition for update is met:
    local condition_met = false
    if is_negate then
        condition_met = not match
    else
        condition_met = match
    end
    if condition_met then
        redis.call('set', KEYS[1], ARGV[1], 'ex', ARGV[4])
        return 1
    else
        return 0
    end
    """

    def __init__(self, redis_url: str):
        """Create the backend and its Redis client.

        Args:
            redis_url: Connection URL passed to ``redis.asyncio.from_url``.
        """
        # decode_responses=True makes reads (GET, pubsub data) return str.
        self.redis_client = redis.from_url(redis_url, decode_responses=True)
        # Registering the script lets redis-py call it via EVALSHA (with an
        # automatic SCRIPT LOAD on a cache miss) instead of resending the
        # full script body on every EVAL.
        self._cas_script = self.redis_client.register_script(
            self._CAS_LUA_SCRIPT,
        )

    @staticmethod
    def _state_key(key: str) -> str:
        """Return the Redis key under which the state for *key* is stored."""
        return f"state:{key}"

    async def publish_event(
        self,
        channel: str,
        message: str,
    ) -> None:
        """Publish *message* on the Redis Pub/Sub *channel*."""
        await self.redis_client.publish(channel, message)

    async def subscribe_listen(
        self,
        channel: str,
    ) -> AsyncGenerator[str, None]:
        """Subscribe to *channel* and yield the payload of each message.

        Only entries whose type is ``"message"`` are yielded; other
        Pub/Sub notifications (e.g. subscribe confirmations) are skipped.
        The subscription is torn down when the generator is closed.
        """
        pubsub = self.redis_client.pubsub()
        try:
            # Subscribe inside the try so a failed subscribe still closes
            # the pubsub object and releases its connection.
            await pubsub.subscribe(channel)
            async for message in pubsub.listen():
                if message["type"] == "message":
                    yield message["data"]
        finally:
            try:
                await pubsub.unsubscribe(channel)
            finally:
                # Always release the connection, even if unsubscribe fails.
                await pubsub.aclose()

    async def set_task_state(
        self,
        key: str,
        state: TaskState,
        ttl: int = 3600,
    ) -> None:
        """Unconditionally store *state* for *key*, expiring after *ttl* s."""
        await self.redis_client.set(
            self._state_key(key),
            state.value,
            ex=ttl,
        )

    async def compare_and_set_state(
        self,
        key: str,
        new_state: TaskState,
        expected_state: TaskState,
        negate: bool = False,
        ttl: int = 3600,
    ) -> bool:
        """
        Implementation of atomic CAS using Lua scripting for Redis.
        The script ensures that the 'Get-Compare-Set' cycle is uninterruptible.

        Args:
            key: Task identifier (stored under ``state:<key>``).
            new_state: State to write when the condition holds.
            expected_state: State the current value is compared against.
            negate: If True, write when the current value does NOT equal
                ``expected_state`` (this includes a missing key).
            ttl: Expiry in seconds applied to the written key.

        Returns:
            True if the new state was written, False otherwise.
        """
        result = await self._cas_script(
            keys=[self._state_key(key)],
            args=[
                new_state.value,
                expected_state.value,
                "1" if negate else "0",
                ttl,
            ],
        )
        return bool(result)

    async def get_task_state(
        self,
        key: str,
    ) -> Optional[TaskState]:
        """Return the stored state for *key*, or None if absent/empty."""
        val = await self.redis_client.get(self._state_key(key))
        return TaskState(val) if val else None

    async def delete_task_state(
        self,
        key: str,
    ) -> None:
        """Delete the stored state for *key* (no-op if it does not exist)."""
        await self.redis_client.delete(self._state_key(key))

    async def aclose(self) -> None:
        """Close the underlying Redis client."""
        await self.redis_client.aclose()