Logo ryp的博客

博客

WyOJ Musume 代码

2025-09-09 10:37:11 By ryp
import asyncio
import base64
import functools
import json
import os
import random
import re
import secrets
import subprocess
import time
import traceback

import aiohttp
import aiosqlite
import websockets

# Module-wide configuration and mutable runtime state.
wsserver = 'ws://172.17.0.2:3001/'   # WebSocket endpoint opened by client_loop()
ws = None                            # live WebSocket connection, assigned in client_loop()
commands = {}                        # command name -> handler, filled by register_cmd()
trusted_senders = [2681243014]       # QQ ids allowed to run privileged commands
confirm_queue = {}                   # confirmation token -> pending asyncio future
reply_queue = {}                     # not referenced by the code visible in this file
file_queue = {}                      # not referenced by the code visible in this file
sqlconn = None                       # aiosqlite connection, assigned in client_loop()
GROUP_ID = 1061173220                # only messages from this group id are handled

def wyoj_enable(qq, name):
    """Set a WyOJ account's usergroup to 'U' and record its QQ number.

    The update is executed with the ``mysql`` client inside the ``uoj-db``
    container.  The call blocks until the client exits; its exit status is
    not inspected (same as before).

    Args:
        qq: QQ number (int or digit string) stored in ``user_info.qq``.
        name: WyOJ username; must consist of ASCII letters, digits and ``_``.

    Raises:
        ValueError: if ``name`` or ``qq`` contains anything that could escape
            the quoted SQL literal.
    """
    # fullmatch: ``^...$`` with re.match would also accept a trailing newline.
    if re.fullmatch(r'[A-Za-z0-9_]+', name) is None:
        raise ValueError('wyoj_enable: Invalid username; possible injection')
    if re.fullmatch(r'[0-9]+', str(qq)) is None:
        raise ValueError('wyoj_enable: Invalid QQ number; possible injection')
    sql = f"update user_info set usergroup = 'U', qq = '{qq}' where username = '{name}'"
    # Argument list instead of a shell string: nothing is re-parsed by /bin/sh.
    subprocess.run(['sudo', 'docker', 'exec', 'uoj-db',
                    'mysql', '-proot', 'app_uoj233', '-e', sql])

def wyoj_disable(name):
    """Set a WyOJ account's usergroup to 'B'.

    The update is executed with the ``mysql`` client inside the ``uoj-db``
    container; the call blocks and the exit status is not inspected.

    Args:
        name: WyOJ username; must consist of ASCII letters, digits and ``_``.

    Raises:
        ValueError: if ``name`` could escape the quoted SQL literal.
    """
    # fullmatch: ``^...$`` with re.match would also accept a trailing newline.
    if re.fullmatch(r'[A-Za-z0-9_]+', name) is None:
        raise ValueError('wyoj_disable: Invalid username; possible injection')
    sql = f"update user_info set usergroup = 'B' where username = '{name}'"
    # Argument list instead of a shell string: nothing is re-parsed by /bin/sh.
    subprocess.run(['sudo', 'docker', 'exec', 'uoj-db',
                    'mysql', '-proot', 'app_uoj233', '-e', sql])

def wyoj_adminize(name):
    """Set a WyOJ account's usergroup to 'S'.

    The update is executed with the ``mysql`` client inside the ``uoj-db``
    container; the call blocks and the exit status is not inspected.

    Args:
        name: WyOJ username; must consist of ASCII letters, digits and ``_``.

    Raises:
        ValueError: if ``name`` could escape the quoted SQL literal.
    """
    # fullmatch: ``^...$`` with re.match would also accept a trailing newline.
    if re.fullmatch(r'[A-Za-z0-9_]+', name) is None:
        raise ValueError('wyoj_adminize: Invalid username; possible injection')
    sql = f"update user_info set usergroup = 'S' where username = '{name}'"
    # Argument list instead of a shell string: nothing is re-parsed by /bin/sh.
    subprocess.run(['sudo', 'docker', 'exec', 'uoj-db',
                    'mysql', '-proot', 'app_uoj233', '-e', sql])

# Nesting (overlapping grants for the same user) is not handled for now.
def wyoj_deadminize(name):
    """Set a WyOJ account's usergroup back to 'U'.

    The update is executed with the ``mysql`` client inside the ``uoj-db``
    container; the call blocks and the exit status is not inspected.

    Args:
        name: WyOJ username; must consist of ASCII letters, digits and ``_``.

    Raises:
        ValueError: if ``name`` could escape the quoted SQL literal.
    """
    # fullmatch: ``^...$`` with re.match would also accept a trailing newline.
    if re.fullmatch(r'[A-Za-z0-9_]+', name) is None:
        raise ValueError('wyoj_deadminize: Invalid username; possible injection')
    sql = f"update user_info set usergroup = 'U' where username = '{name}'"
    # Argument list instead of a shell string: nothing is re-parsed by /bin/sh.
    subprocess.run(['sudo', 'docker', 'exec', 'uoj-db',
                    'mysql', '-proot', 'app_uoj233', '-e', sql])

def wyoj_admins():
    """Return the raw ``mysql`` client output listing usernames with usergroup 'S'.

    The command string is a constant, so running it through the shell is safe.
    """
    command = ("sudo docker exec uoj-db mysql -proot app_uoj233 -e "
               "\"select username from user_info where usergroup = 'S';\"")
    # Close the pipe deterministically instead of relying on garbage collection.
    with os.popen(command) as pipe:
        return pipe.read()

def wyoj_banned():
    """Return the raw ``mysql`` client output listing usernames with usergroup 'B'.

    The command string is a constant, so running it through the shell is safe.
    """
    command = ("sudo docker exec uoj-db mysql -proot app_uoj233 -e "
               "\"select username from user_info where usergroup = 'B';\"")
    # Close the pipe deterministically instead of relying on garbage collection.
    with os.popen(command) as pipe:
        return pipe.read()

async def db_insert_bind(uid, name):
    """Insert a ``(qq, name)`` row into the ``binds`` table and commit.

    Args:
        uid: QQ number; stored as text, as the previous quoted literal was.
        name: WyOJ username.
    """
    cursor = await sqlconn.cursor()
    try:
        # Bound parameters: values can no longer break out of the statement.
        await cursor.execute('insert into binds (qq, name) values (?, ?)',
                             (str(uid), str(name)))
        await sqlconn.commit()
    finally:
        await cursor.close()

async def db_delete_bind(name):
    """Delete every ``binds`` row for the given WyOJ username and commit."""
    cursor = await sqlconn.cursor()
    try:
        # Bound parameter: the value can no longer break out of the statement.
        await cursor.execute('delete from binds where name = ?', (str(name),))
        await sqlconn.commit()
    finally:
        await cursor.close()

async def db_query_by_id(uid):
    """Return all ``(qq, name)`` rows of ``binds`` whose ``qq`` equals ``uid``."""
    cursor = await sqlconn.cursor()
    try:
        # Bound as text, matching the previous quoted literal, but injection-safe.
        await cursor.execute('select qq, name from binds where qq = ?', (str(uid),))
        return await cursor.fetchall()
    finally:
        await cursor.close()

async def db_query_by_name(name):
    """Return the single ``(qq, name)`` row bound to ``name``, or ``None``.

    Raises:
        Exception: if more than one row exists for ``name``.
    """
    cursor = await sqlconn.cursor()
    try:
        # Bound parameter: the value can no longer break out of the statement.
        await cursor.execute('select qq, name from binds where name = ?', (str(name),))
        rows = await cursor.fetchall()
    finally:
        await cursor.close()
    if len(rows) > 1:
        raise Exception('重复绑定')
    return rows[0] if rows else None

async def db_query_binds():
    """Return every ``(qq, name)`` row of ``binds``, ordered by ``qq``."""
    cursor = await sqlconn.cursor()
    try:
        await cursor.execute('select qq, name from binds order by qq')
        return await cursor.fetchall()
    finally:
        # Release the cursor even when the query raises.
        await cursor.close()

async def db_count_bound(qq):
    """Return how many ``binds`` rows belong to the given QQ number."""
    cursor = await sqlconn.cursor()
    try:
        # Bound as text, matching the previous quoted literal, but injection-safe.
        await cursor.execute('select count (*) from binds where qq = ?', (str(qq),))
        row = await cursor.fetchone()
    finally:
        await cursor.close()
    return row[0]

async def db_check_bound(name):
    """Return True when at least one ``binds`` row exists for ``name``."""
    cursor = await sqlconn.cursor()
    try:
        # Bound parameter: the value can no longer break out of the statement.
        await cursor.execute('select count (*) from binds where name = ?', (str(name),))
        row = await cursor.fetchone()
    finally:
        await cursor.close()
    return row[0] > 0

async def db_banned(qq):
    """Return True when the given QQ number has a row in ``bans``."""
    cursor = await sqlconn.cursor()
    try:
        # Bound as text, matching the previous quoted literal, but injection-safe.
        await cursor.execute('select count (*) from bans where qq = ?', (str(qq),))
        row = await cursor.fetchone()
    finally:
        await cursor.close()
    return row[0] > 0

async def db_banned_all():
    """Return every row of ``bans`` as one-element ``(qq,)`` tuples."""
    cursor = await sqlconn.cursor()
    try:
        await cursor.execute('select qq from bans')
        return await cursor.fetchall()
    finally:
        # Release the cursor even when the query raises.
        await cursor.close()

async def db_ban(qq):
    """Insert the given QQ number into ``bans`` and commit."""
    cursor = await sqlconn.cursor()
    try:
        # Bound as text, matching the previous quoted literal, but injection-safe.
        await cursor.execute('insert into bans (qq) values (?)', (str(qq),))
        await sqlconn.commit()
    finally:
        await cursor.close()

async def db_unban(qq):
    """Delete the given QQ number from ``bans`` and commit."""
    cursor = await sqlconn.cursor()
    try:
        # Bound as text, matching the previous quoted literal, but injection-safe.
        await cursor.execute('delete from bans where qq = ?', (str(qq),))
        await sqlconn.commit()
    finally:
        await cursor.close()

# CQ-code patterns: the quoted message id of a reply, and the QQ id of an @-mention.
refpatt = re.compile(r'\[CQ:reply,id=([0-9]*)\]')
atpatt = re.compile(r'\[CQ:at,qq=([0-9]*)\]')


class MessageMeta:
    """Parsed view of one incoming group-message event.

    Attributes set by ``__init__``:
        data:  the decoded event dict as received.
        uid:   ``data['user_id']``.
        gid:   ``data['group_id']``.
        msgid: ``data['message_id']``.
        refs:  ids (strings) captured from every ``[CQ:reply,id=...]`` code.
        ats:   ids (strings) captured from every ``[CQ:at,qq=...]`` code.
        msg:   raw message with the leading bracketed codes and spaces removed.
        cmd:   ``msg`` split on whitespace.

    Raises:
        KeyError: if the event lacks one of the keys read above.
    """

    def __init__(self, data):
        s = data['raw_message']
        self.data = data
        self.uid = data['user_id']
        self.gid = data['group_id']
        self.msgid = data['message_id']
        self.refs = refpatt.findall(s)
        self.ats = atpatt.findall(s)

        # Skip any mix of leading "[...]" segments and spaces.  Previously the
        # scan stopped at the first space, so "[CQ:at,..] [CQ:at,..] /cmd" was
        # never recognised as a command.
        pos, end = 0, len(s)
        while pos < end:
            if s[pos] == ' ':
                pos += 1
            elif s[pos] == '[':
                close = s.find(']', pos)
                # An unterminated segment swallows the rest of the message.
                pos = end if close == -1 else close + 1
            else:
                break
        self.msg = s[pos:]
        self.cmd = self.msg.split()

async def reply_msg(meta, msg):
    """Send ``msg`` to ``meta.gid`` as a reply quoting message ``meta.msgid``.

    Returns whatever ``ws.send`` returns.
    """
    segments = [
        {'type': 'reply', 'data': {'id': meta.msgid}},
        {'type': 'text', 'data': {'text': msg}},
    ]
    request = {
        'action': 'send_group_msg',
        'params': {'group_id': meta.gid, 'message': segments},
    }
    return await ws.send(json.dumps(request))

def download_dir(s):
    """Return the path of entry ``s`` under the ``./download`` directory."""
    return './download/{}'.format(s)

FILE_KEEP_MIN = 60              # minutes a downloaded file is kept before deletion
REPOTER_GAP = 5                 # seconds between download progress reports
CHUNK_SIZE = 16 * 1024 * 1024   # bytes requested per read while downloading

async def receive_file(meta, data):
    """Download a ZIP file posted to the group, keep it for a while, then delete it.

    The file is saved as ``download_dir(str(meta.msgid))``.  Progress is
    reported every ``REPOTER_GAP`` seconds while the download runs; the file
    is removed ``FILE_KEEP_MIN`` minutes after a successful download.

    Args:
        meta: MessageMeta of the file message (used for replies and the name).
        data: decoded event; ``data['message'][0]['data']`` must provide
            ``file_size``, ``url`` and ``file``.
    """
    info = data['message'][0]['data']
    size = int(info['file_size'])
    url = info['url']
    name = info['file']
    if not name.endswith('.zip'):
        return await reply_msg(meta, '识别到非 ZIP 格式文件,未下载')
    await reply_msg(meta, '识别到 ZIP 格式文件,大小 %.3lf MB 正在下载' % (size / 1048576))

    start_time = time.time()
    rx = 0
    finished = False

    async def reporter():
        # Reads rx/finished from the enclosing scope; never assigns them.
        while not finished:
            duration = time.time() - start_time
            if rx != 0:
                await reply_msg(meta, '已下载 %.3lf MB = %.3lf MB/s = %.3lf Mbps, ETA %.2lf 秒'
                                % (rx / 1048576, rx / 1048576 / duration,
                                   rx * 8 / 1048576 / duration, duration / rx * size))
            await asyncio.sleep(REPOTER_GAP)

    # Keep a reference so the task cannot be garbage-collected mid-flight.
    reporter_task = asyncio.create_task(reporter())

    filename = download_dir(str(meta.msgid))
    try:
        async with aiohttp.ClientSession() as sess:
            async with sess.get(url) as resp:
                if resp.status != 200:
                    # Previously fell through and saved the error body as the file.
                    return await reply_msg(meta, f'请求 {url} 失败。HTTP 状态码 {resp.status}')
                os.makedirs(os.path.dirname(filename), exist_ok=True)
                with open(filename, 'wb') as f:
                    async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                        f.write(chunk)
                        rx += len(chunk)
    finally:
        # Stop progress reports on success and on failure alike.
        finished = True
        reporter_task.cancel()

    # Guard against a zero interval on coarse clocks.
    duration = max(time.time() - start_time, 1e-9)
    await reply_msg(meta, f'已成功下存到 {filename},耗时 %.3lf 秒,合 %.3lf MB/s = %.3lf Mbps\n将在 {FILE_KEEP_MIN} 分钟后删除。'
                    % (duration, size / duration / 1048576, size * 8 / 1048576 / duration))
    # The HTTP session is already closed here; only the timer is pending.
    await asyncio.sleep(FILE_KEEP_MIN * 60)
    await reply_msg(meta, f'到时,删除文件 {filename}')
    # ``filename`` is already the full path; wrapping it in download_dir()
    # again pointed at a non-existent file, so nothing was ever deleted.
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass

def random_token():
    """Return 16 random lowercase ASCII letters.

    The result is embedded in confirmation tokens, so it is drawn from the
    ``secrets`` CSPRNG rather than the predictable ``random`` generator.
    The length must stay 16: /confirm strips exactly that many characters.
    """
    alphabet = 'abcdefghijklmnopqrstuvwxyz'
    return ''.join(secrets.choice(alphabet) for _ in range(16))
def gen_confirm_token(meta):
    """Return a base64 token for ``meta.uid`` that is not in ``confirm_queue``.

    The token encodes ``str(meta.uid)`` followed by ``random_token()``; a new
    random part is drawn until the encoded value is unused.
    """
    owner = str(meta.uid)
    token = None
    while token is None or token in confirm_queue:
        plain = owner + random_token()
        token = base64.b64encode(plain.encode('utf-8')).decode('utf-8')
    return token

CONFIRM_TIMEOUT = 60  # seconds to wait for the matching /confirm


async def confirm(meta, msg):
    """Ask the sender to confirm an action and wait for ``/confirm <token>``.

    Args:
        meta: MessageMeta of the command being confirmed.
        msg: prompt shown above the confirmation instructions.

    Returns:
        The value the pending future was resolved with, or ``False`` when
        ``CONFIRM_TIMEOUT`` seconds pass without confirmation.
    """
    token = gen_confirm_token(meta)
    future = asyncio.get_running_loop().create_future()
    confirm_queue[token] = future
    try:
        # Inside the try so a failed send cannot leave the token registered.
        await reply_msg(meta, f'{msg}\n{CONFIRM_TIMEOUT} 秒内 /confirm {token} 来确认')
        try:
            result = await asyncio.wait_for(future, timeout=CONFIRM_TIMEOUT)
        except asyncio.TimeoutError:
            await reply_msg(meta, '超时,取消操作')
            result = False
        else:
            await reply_msg(meta, '成功确认')
    finally:
        confirm_queue.pop(token, None)
    return result

def register_cmd(cmd):
    """Return a decorator that registers a handler under the name ``cmd``.

    The handler is stored in ``commands[cmd]`` and returned unchanged, so the
    decorated name stays a coroutine function with its own name and docstring
    (it used to be replaced by an anonymous synchronous pass-through wrapper).

    Raises:
        Exception: at decoration time, if ``cmd`` is already registered.
    """
    def decorator(fn):
        if cmd in commands:
            raise Exception(f'试图第二次注册命令 {cmd}')
        commands[cmd] = fn
        return fn
    return decorator

def trusted_sender(fn):
    """Decorate a command handler so only ``trusted_senders`` may run it.

    The wrapper takes the message metadata as its first argument, forwards
    everything to ``fn`` when ``meta.uid`` is trusted, and otherwise replies
    with a permission error.  ``functools.wraps`` keeps ``fn``'s name and
    docstring on the wrapper.
    """
    @functools.wraps(fn)
    async def wrapper(meta, *args, **kwargs):
        if meta.uid not in trusted_senders:
            return await reply_msg(meta, '执行该命令的权限不足')
        return await fn(meta, *args, **kwargs)
    return wrapper

@register_cmd('/confirm')
async def cmd_confirm(meta, arg):
    """Handle ``/confirm <token>``: resolve the pending confirmation.

    The token must be registered in ``confirm_queue`` and must have been
    issued to the sender (the decoded token minus its 16 random characters
    equals ``str(meta.uid)``).
    """
    if len(arg) != 2:
        return await reply_msg(meta, '需要令牌')
    token = arg[1]
    future = confirm_queue.get(token)
    if future is None:
        return await reply_msg(meta, '无效令牌')
    uid = base64.b64decode(token).decode('utf-8')[:-16]
    if uid != str(meta.uid):
        return await reply_msg(meta, f'错误的确认者。期望 {uid},得到 {meta.uid}')
    # A repeated /confirm can arrive before the waiter removes the token;
    # set_result on a finished future would raise InvalidStateError.
    if not future.done():
        future.set_result(True)

@register_cmd('/testre')
async def cmd_testre(meta, arg):
    """Handle ``/testre``: ask for a reply and echo it back.

    NOTE(review): ``wait_reply`` is not defined in the visible part of this
    file; this command fails with NameError unless it is provided elsewhere.
    """
    res = await wait_reply(meta, '请回复这条消息')
    if not res:
        # Stop here: joining an empty/None result below would fail.
        return await reply_msg(meta, '为啥不回复?')
    words = ' '.join(res)
    await reply_msg(meta, f'你说:{words}')

@register_cmd('/ping')
async def cmd_ping(meta, arg):
    """Handle ``/ping``: reply ``PONG <uid>`` plus the sender's status tags."""
    pieces = [f'PONG {meta.uid}']
    if meta.uid in trusted_senders:
        pieces.append(' (privileged)')
    if await db_banned(meta.uid):
        pieces.append(' (banned)')
    await reply_msg(meta, ''.join(pieces))

MAX_BINDS = 3  # maximum number of WyOJ accounts bound to one QQ number


@register_cmd('/bind')
async def cmd_bind(meta, arg):
    """Handle ``/bind <username>``: bind a WyOJ account to the sender's QQ.

    After confirmation the bind is recorded and the account is enabled via
    ``wyoj_enable``.  Refused when the sender is banned, when the name is
    already bound, or when the sender already has ``MAX_BINDS`` accounts.
    """
    if len(arg) != 2:
        return await reply_msg(meta, '需要 WyOJ 用户名\nUsage: /bind username')
    name = arg[1]
    if re.match(r'^([A-Za-z0-9_]*)$', name) is None:
        return await reply_msg(meta, '不合法用户名')
    # Binding enables the account, so a banned QQ could otherwise undo a
    # /ban simply by unbinding and binding again.
    if await db_banned(meta.uid):
        return await reply_msg(meta, f'{meta.uid} 已被封禁,无法绑定')
    if not await confirm(meta, f'确认将 {meta.uid} 绑定到 {name} 吗?'):
        return
    # Checked after the confirmation wait so the state is current.
    if await db_check_bound(name):
        qq, _ = await db_query_by_name(name)
        return await reply_msg(meta, f'无法绑定 {name};已绑定到 {qq}')
    if await db_count_bound(meta.uid) >= MAX_BINDS:
        return await reply_msg(meta, f'无法绑定更多账号。单个 QQ 绑定最大值为 {MAX_BINDS}')
    await db_insert_bind(meta.uid, name)
    wyoj_enable(meta.uid, name)
    await reply_msg(meta, f'成功将 {name} 绑定到 {meta.uid}')

@register_cmd('/rebind')
@trusted_sender
async def cmd_rebind(meta, arg):
    """Handle ``/rebind <username>`` (trusted only): re-apply an existing bind.

    The bind row is deleted and re-inserted with the same QQ number, and the
    account is enabled again via ``wyoj_enable``.
    """
    if len(arg) != 2:
        # The usage line used to name /bind instead of this command.
        return await reply_msg(meta, '需要 WyOJ 用户名\nUsage: /rebind username')
    name = arg[1]
    if re.match(r'^([A-Za-z0-9_]*)$', name) is None:
        return await reply_msg(meta, '不合法用户名')
    if not await db_check_bound(name):
        return await reply_msg(meta, '无绑定')
    qq, _ = await db_query_by_name(name)
    await db_delete_bind(name)
    await db_insert_bind(qq, name)
    wyoj_enable(qq, name)
    await reply_msg(meta, '已重新绑定')

@register_cmd('/refresh')
@trusted_sender
async def cmd_refresh(meta, arg):
    """Handle ``/refresh`` (trusted only): re-apply every recorded bind.

    Each bind row is deleted, re-inserted and enabled again; the reply lists
    the processed pairs and their count.
    """
    lines = []
    for qq, name in await db_query_binds():
        await db_delete_bind(name)
        await db_insert_bind(qq, name)
        wyoj_enable(qq, name)
        lines.append(f'{qq}\t\t{name}\n')
    listing = ''.join(lines)
    await reply_msg(meta, f'已成功绑定:\n{listing}\n计 {len(lines)} 人')

@register_cmd('/unbind')
async def cmd_unbind(meta, arg):
    """Handle ``/unbind <username>``: remove a bind and disable the account.

    Only the bound QQ itself or a trusted sender may unbind.  After
    confirmation the bind row is deleted and ``wyoj_disable`` is called.
    """
    if len(arg) != 2:
        return await reply_msg(meta, '需要 WyOJ 用户名\nUsage: /unbind username')
    name = arg[1]
    if re.match(r'^([A-Za-z0-9_]*)$', name) is None:
        return await reply_msg(meta, '不合法用户名')
    if not await db_check_bound(name):
        return await reply_msg(meta, '该账号未绑定')
    qq, _ = await db_query_by_name(name)
    if qq != str(meta.uid):
        if meta.uid in trusted_senders:
            await reply_msg(meta, f'你是管理员 {meta.uid},强制解绑,等待确认')
        else:
            return await reply_msg(meta, f'{name} 与 {qq} 的绑定无法由你 ({meta.uid}) 解绑')
    # Name the QQ that actually owns the bind; on a forced unbind the
    # requester's own id used to be shown instead.
    if not await confirm(meta, f'确认将 {qq} 从 {name} 解绑吗?'):
        return
    # Re-check: the bind may have been removed during the confirmation wait.
    if not await db_check_bound(name):
        return await reply_msg(meta, '尚无绑定,无法解绑')
    await db_delete_bind(name)
    wyoj_disable(name)
    await reply_msg(meta, f'成功将 {qq} 与 {name} 解绑')

@register_cmd('/adminize')
@trusted_sender
async def cmd_adminize(meta, arg):
    """Handle ``/adminize <username> [minutes]`` (trusted only).

    Grants admin rights via ``wyoj_adminize``.  With a duration, this handler
    sleeps for that many minutes and then revokes the rights again; without
    one the grant is not revoked (shown as ``-1`` in the prompt).
    """
    if len(arg) not in [2, 3]:
        return await reply_msg(meta, '需要用户名及时间\nUsage: /adminize username [time-in-minutes]')
    name = arg[1]
    if re.match(r'^([A-Za-z0-9_]*)$', name) is None:
        return await reply_msg(meta, '不合法用户名')
    # Named ``minutes`` so the ``time`` module is not shadowed.
    minutes = -1
    if len(arg) == 3:
        try:
            minutes = int(arg[2])
        except ValueError:
            return await reply_msg(meta, f'不合法的时间 {arg[2]}')
        # An explicit zero/negative duration used to grant permanent rights.
        if minutes <= 0:
            return await reply_msg(meta, f'不合法的时间 {arg[2]}')
    if not await confirm(meta, f'确定要赋予 {name} 以 {minutes} 分钟的管理员权限吗?'):
        return
    wyoj_adminize(name)
    await reply_msg(meta, f'{name} 已被赋予管理员权限')
    if minutes > 0:
        await asyncio.sleep(minutes * 60)
        wyoj_deadminize(name)
        await reply_msg(meta, f'{name} 的管理员权限已到期')

@register_cmd('/deadminize')
@trusted_sender
async def cmd_deadminize(meta, arg):
    """Handle ``/deadminize <username>`` (trusted only): revoke admin rights.

    A third argument is still accepted and ignored, as before.
    """
    if len(arg) not in [2, 3]:
        return await reply_msg(meta, '需要用户名\nUsage: /deadminize username')
    name = arg[1]
    if re.match(r'^([A-Za-z0-9_]*)$', name) is None:
        return await reply_msg(meta, '不合法用户名')
    if not await confirm(meta, f'确定要取消 {name} 的管理员权限吗?'):
        return
    wyoj_deadminize(name)
    # The reply used to claim that admin rights had been *granted*.
    await reply_msg(meta, f'{name} 的管理员权限已被取消')

async def cmd_query_qq(meta, arg):
    """Reply with the accounts bound to the QQ number in ``arg[2]``."""
    qq = arg[2]
    try:
        int(qq)
    except ValueError:
        # Return here: the lookup used to run even for an invalid number.
        return await reply_msg(meta, f'无效 QQ 号 {qq}')
    res = await db_query_by_id(qq)
    names = ", ".join(row[1] for row in res)
    await reply_msg(meta, f'QQ {qq} 绑定到 {len(res)} 个账号:{names}')

async def cmd_query_me(meta, arg):
    """Reply with the accounts bound to the sender's own QQ number."""
    # Build a new argument list; appending to ``arg`` modified the caller's
    # list (``meta.cmd``) in place.
    await cmd_query_qq(meta, [*arg, 'qq', str(meta.uid)])

async def cmd_query_name(meta, arg):
    """Reply with the QQ number that the username in ``arg[2]`` is bound to."""
    name = arg[2]
    if not re.match(r'^([A-Za-z0-9_]*)$', name):
        return await reply_msg(meta, '不合法用户名')
    row = await db_query_by_name(name)
    text = f'{name} 绑定到 {row[0]}' if row else f'{name} 未绑定'
    await reply_msg(meta, text)

@trusted_sender
async def cmd_query_all(meta, arg):
    """Reply with the whole bind table, one ``qq<TAB><TAB>name`` line per row."""
    rows = await db_query_binds()
    table = "\n".join(row[0] + "\t\t" + row[1] for row in rows)
    await reply_msg(meta, f'绑定表:\n{table}')

@trusted_sender
async def cmd_query_banned(meta, arg):
    """Reply with the raw output of ``wyoj_banned()``."""
    listing = wyoj_banned()
    await reply_msg(meta, listing)

@trusted_sender
async def cmd_query_admins(meta, arg):
    """Reply with the raw output of ``wyoj_admins()``."""
    listing = wyoj_admins()
    await reply_msg(meta, listing)

@register_cmd('/query')
async def cmd_query(meta, arg):
    """Handle ``/query``: route to the sub-query selected by the arguments.

    No argument queries the sender; one argument selects ``all``, ``banned``
    or ``admins``; two arguments select ``qq <id>`` or ``name <username>``.
    Anything else gets the usage message.
    """
    argc = len(arg)
    if argc == 1:
        return await cmd_query_me(meta, arg)

    if argc == 2:
        routes = {
            'all': cmd_query_all,
            'banned': cmd_query_banned,
            'admins': cmd_query_admins,
        }
    elif argc == 3:
        routes = {'qq': cmd_query_qq, 'name': cmd_query_name}
    else:
        routes = {}

    handler = routes.get(arg[1]) if routes else None
    if handler is not None:
        return await handler(meta, arg)
    await reply_msg(meta, '期望零个或两个参数\nUsage: /query || /query qq qqid || /query name username')

@register_cmd('/ban')
@trusted_sender
async def cmd_ban(meta, arg):
    """Handle ``/ban <qq>`` (trusted only): ban a QQ number.

    After confirmation the number is recorded in ``bans`` and every account
    bound to it is disabled via ``wyoj_disable``.
    """
    if len(arg) != 2:
        # The argument is a QQ number, not a username as the usage line said.
        return await reply_msg(meta, '需要一个参数\nUsage: /ban qq')
    try:
        qq = str(int(arg[1]))
    except ValueError:
        # ``qq`` is unbound here; reading it raised UnboundLocalError.
        return await reply_msg(meta, f'无效 QQ 号 {arg[1]}')
    if not await confirm(meta, f'确定要封禁 {qq} 及其名下的所有账号吗?'):
        return
    await db_ban(qq)
    res = await db_query_by_id(qq)
    for _, name in res:
        wyoj_disable(name)
    names = ", ".join(row[1] for row in res)
    await reply_msg(meta, f'被封禁的用户:{names}')

@register_cmd('/unban')
@trusted_sender
async def cmd_unban(meta, arg):
    """Handle ``/unban <qq>`` (trusted only): remove a QQ number from ``bans``.

    Accounts that were disabled by the ban are not re-enabled.
    """
    if len(arg) != 2:
        # The argument is a QQ number, not a username as the usage line said.
        return await reply_msg(meta, '需要一个参数\nUsage: /unban qq')
    try:
        qq = str(int(arg[1]))
    except ValueError:
        # ``qq`` is unbound here; reading it raised UnboundLocalError.
        return await reply_msg(meta, f'无效 QQ 号 {arg[1]}')
    if not await confirm(meta, f'确定要解封 {qq} 吗(不解封账号)?'):
        return
    await db_unban(qq)
    await reply_msg(meta, '成功解封')

@register_cmd('/banlist')
async def cmd_banlist(meta, arg):
    """Handle ``/banlist``: reply with every QQ number stored in ``bans``."""
    rows = await db_banned_all()
    listing = ", ".join(row[0] for row in rows)
    await reply_msg(meta, f'被封禁的 QQ:{listing}')

@register_cmd('/entrust')
@trusted_sender
async def cmd_entrust(meta, arg):
    """Handle ``/entrust`` (trusted only): trust every @-mentioned user.

    The mentioned QQ ids come from ``meta.ats``; after confirmation each one
    not yet present is appended to ``trusted_senders``.
    """
    if not meta.ats:
        # Return here: the command used to continue with nobody to trust.
        return await reply_msg(meta, '需要 At 某人\nUsage: [@someone..] /entrust [@someone..]')
    # Validate before asking for confirmation (an empty id fails int()).
    try:
        targets = [int(i) for i in meta.ats]
    except ValueError:
        return await reply_msg(meta, '非法 At')
    mentioned = ' '.join(meta.ats)
    if not await confirm(meta, f'确定信任 {mentioned} 吗?'):
        return
    for uid in targets:
        # Skip ids already trusted so the list holds no duplicates.
        if uid not in trusted_senders:
            trusted_senders.append(uid)
    await reply_msg(meta, '成功信任')

@register_cmd('/detrust')
@trusted_sender
async def cmd_detrust(meta, arg):
    """Handle ``/detrust`` (trusted only): stop trusting every @-mentioned user.

    Mentioned ids that are not currently trusted are reported and skipped;
    the others are removed from ``trusted_senders`` after confirmation.
    """
    if not meta.ats:
        # Return here; the usage line also used to name /entrust.
        return await reply_msg(meta, '需要 At 某人\nUsage: [@someone..] /detrust [@someone..]')
    # Validate before asking for confirmation (an empty id fails int()).
    try:
        targets = [int(i) for i in meta.ats]
    except ValueError:
        return await reply_msg(meta, '非法 At')
    mentioned = ' '.join(meta.ats)
    if not await confirm(meta, f'确定取消信任 {mentioned} 吗?'):
        return
    for uid in targets:
        if uid in trusted_senders:
            trusted_senders.remove(uid)
        else:
            # Was a plain string (missing f prefix), followed by a
            # list.remove() that raised ValueError for this id.
            await reply_msg(meta, f'其实 {uid} 并未被信任。')
    await reply_msg(meta, '成功取消信任')

@register_cmd('/trusts')
async def cmd_trusts(meta, arg):
    """Handle ``/trusts``: reply with the space-separated trusted QQ ids."""
    listing = ' '.join(str(uid) for uid in trusted_senders)
    await reply_msg(meta, f'信任用户:{listing}')


async def run_cmd(cmd):
    """Run ``cmd`` through the shell and return ``(stdout, stderr)`` as text.

    Output is decoded as UTF-8; undecodable bytes become U+FFFD instead of
    raising UnicodeDecodeError, so arbitrary tool output can be reported.

    Args:
        cmd: shell command line.  It is interpreted by the shell, so callers
            must only interpolate validated values.
    """
    proc = await asyncio.create_subprocess_shell(
        cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE)
    stdout, stderr = await proc.communicate()
    return (stdout.decode('utf-8', errors='replace'),
            stderr.decode('utf-8', errors='replace'))

@register_cmd('/upload')
@trusted_sender
async def cmd_upload(meta, arg):
    """Handle ``/upload <problem-id>`` (trusted only), sent as a reply.

    The replied-to message id selects a previously downloaded file
    (``download_dir(<id>)``), which is handed to ``/home/ryp/upload`` together
    with the problem id; the tool's output is echoed back.
    """
    if len(arg) < 2:
        # ``arg[1]`` used to raise IndexError when no problem id was given.
        return await reply_msg(meta, '非法题号')
    try:
        pid = int(arg[1])
    except ValueError:
        return await reply_msg(meta, '非法题号')
    if len(meta.refs) != 1:
        return await reply_msg(meta, '需要有且仅有一个指向题目数据压缩包的引用')
    ref = meta.refs[0]
    # An empty id would resolve to the download directory itself and pass
    # the existence check below.
    if not ref or not os.path.exists(download_dir(ref)):
        return await reply_msg(meta, '数据文件未被下载:错误的引用或者已经超时')
    await reply_msg(meta, '正在上传数据')
    # pid is an int and ref is digits only, so the shell command is safe.
    stdout, stderr = await run_cmd(f'/home/ryp/upload {pid} {download_dir(ref)}')
    await reply_msg(meta, f'命令执行完毕。标准输出:\n{stdout}\n\n标准错误:{stderr}\n'
                          f'若无错误,请打开 https://oj.ryp.org.cn/problem/{pid}/manage/data 点击校验配置并同步数据')

async def dispatch(meta):
    """Run the handler registered for ``meta.cmd[0]``, or report an unknown command."""
    name = meta.cmd[0]
    try:
        handler = commands[name]
    except KeyError:
        return await reply_msg(meta, f'未知命令 {name}')
    await handler(meta, meta.cmd)

async def handle_msg(data):
    """Process one raw WebSocket frame.

    Frames that are not messages of group ``GROUP_ID`` are ignored.  A file
    message goes to ``receive_file``; a message whose text starts with ``/``
    goes to ``dispatch``.  Any other exception is printed with a traceback.

    Args:
        data: JSON text of the event.
    """
    try:
        data = json.loads(data)
        try:
            meta = MessageMeta(data)
            segments = data['message']
        except KeyError:
            # Events without the group-message keys are not for us.  Only
            # parsing is covered, so a KeyError raised inside a handler is
            # reported below instead of being silently dropped.
            return
        if meta.gid != GROUP_ID:
            return
        if segments and segments[0].get('type') == 'file':
            return await receive_file(meta, data)
        # startswith: indexing msg[0] raised IndexError for messages whose
        # text is empty after stripping the CQ codes (e.g. a lone image).
        if not meta.msg.startswith('/'):
            return
        joined = ' '.join(meta.cmd)
        print(f'收到群 {meta.gid} 的用户 {meta.uid} 的命令:{joined}')
        await dispatch(meta)
    except Exception as e:
        print(f'未知错误:{e}')
        traceback.print_exc()

async def client_loop():
    """Connect to the WebSocket server and the database, then serve messages.

    Every incoming frame is handled in its own task.  The database and the
    WebSocket are closed when the receive loop ends or fails.
    """
    global ws, sqlconn

    ws = await websockets.connect(wsserver)
    try:
        sqlconn = await aiosqlite.connect('data.db')
        try:
            print('已成功连接到 NapCat 服务器')
            # The event loop keeps only weak references to tasks; hold them
            # here so a running handler cannot be garbage-collected.
            running = set()
            async for msg in ws:
                task = asyncio.create_task(handle_msg(msg))
                running.add(task)
                task.add_done_callback(running.discard)
        finally:
            await sqlconn.close()
    finally:
        await ws.close()

# Start the bot only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(client_loop())

评论

暂无评论

发表评论

可以用@mike来提到mike这个用户,mike会被高亮显示。如果你真的想打“@”这个字符,请用“@@”。