diff --git a/README.md b/README.md index 20de0c3..250b8d1 100644 --- a/README.md +++ b/README.md @@ -22,10 +22,15 @@ To add the bot to your server, visit https://discordapp.com/oauth2/authorize?cli

- To add several emotes from a zip or tar archive, run @Emote Manager add-archive with an attached file. + To add several emotes from a zip or tar archive, run @Emote Manager import with an attached file. You can also pass a URL to a zip or tar archive.

+

+ @Emote Manager export [animated/static] creates a zip file of all emotes + suitable for use with the import command. +

+

@Emote Manager list gives you a list of all emotes on this server.

diff --git a/cogs/emote.py b/cogs/emote.py index e832b3e..fecb592 100644 --- a/cogs/emote.py +++ b/cogs/emote.py @@ -11,6 +11,8 @@ import operator import posixpath import traceback import urllib.parse +import zipfile +import warnings import weakref import aioec @@ -23,10 +25,14 @@ import utils import utils.archive import utils.image from utils import errors +from utils.converter import emote_type_filter from utils.paginator import ListPaginator logger = logging.getLogger(__name__) +# guilds can have duplicate emotes, so let us create zips to match +warnings.filterwarnings('ignore', module='zipfile', category=UserWarning, message=r"^Duplicate name: .*$") + class UserCancelledError(commands.UserInputError): pass @@ -192,8 +198,31 @@ class Emotes(commands.Cog): await context.send(message) - @commands.command(name='add-zip', aliases=['add-tar', 'add-from-zip', 'add-from-tar']) - async def add_archive(self, context, url=None): + @commands.command(name='export') + async def export(self, context, *, image_type: emote_type_filter = lambda _: True): + """Export all emotes from this server to a zip file, suitable for use with the import command. + + If “animated” is provided, only include animated emotes. + If “static” is provided, only include static emotes. + Otherwise, or if “all” is provided, export all emotes. + """ + emotes = list(filter(image_type, context.guild.emojis)) + if not emotes: + raise commands.BadArgument('No emotes of that type were found in this server.') + + out = io.BytesIO() + async with context.typing(): + with zipfile.ZipFile(out, 'w', compression=zipfile.ZIP_STORED) as zip: + async def store(emote): + data = await self.fetch_safe(str(emote.url), validate_headers=False) + zip.writestr(f'{emote.name}.{"gif" if emote.animated else "png"}', data) + await utils.gather_or_cancel(*(store(emote) for emote in emotes)) + + out.seek(0) + await context.send(file=discord.File(out, f'emotes-{context.guild.id}.zip')) + + @commands.command(name='import', aliases=['add-zip', 'add-tar', 'add-from-zip', 'add-from-tar']) + async def import_(self, context, url=None): """Add several emotes from a .zip or .tar archive. You may either pass a URL to an archive or upload one as an attachment. @@ -251,10 +280,10 @@ class Emotes(commands.Cog): return image_data return await self.add_safe_bytes(context, name, author_id, image_data, reason=reason) - async def fetch_safe(self, url, valid_mimetypes=None): + async def fetch_safe(self, url, valid_mimetypes=None, *, validate_headers=False): """Try to fetch a URL. On error return a string that should be sent to the user.""" try: - return await self.fetch(url, valid_mimetypes=valid_mimetypes) + return await self.fetch(url, valid_mimetypes=valid_mimetypes, validate_headers=validate_headers) except asyncio.TimeoutError: return 'Error: retrieving the image took too long.' except ValueError: @@ -290,7 +319,7 @@ class Emotes(commands.Cog): s = f'Emote {emote} successfully created' return s + ' as a GIF.' if converted else s + '.' - async def fetch(self, url, valid_mimetypes=None): + async def fetch(self, url, valid_mimetypes=None, *, validate_headers=True): valid_mimetypes = valid_mimetypes or self.IMAGE_MIMETYPES def validate_headers(response): response.raise_for_status() @@ -309,7 +338,7 @@ class Emotes(commands.Cog): except aiohttp.ClientError as exc: raise errors.EmoteManagerError('An error occurred while retrieving the file: {exc}') - await validate(self.http.head(url, timeout=self.bot.config.get('http_head_timeout', 10))) + if validate_headers: await validate(self.http.head(url, timeout=self.bot.config.get('http_head_timeout', 10))) return await validate(self.http.get(url)) async def create_emote_from_bytes(self, guild, name, author_id, image_data: bytes, *, reason=None): @@ -354,26 +383,17 @@ class Emotes(commands.Cog): await context.send(fr'Emote successfully renamed to \:{new_name}:') @commands.command(aliases=('ls', 'dir')) - async def list(self, context, animated=''): + async def list(self, context, animated: emote_type_filter = lambda _: True): """A list of all emotes on this server. The list shows each emote and its raw form. If "animated" is provided, only show animated emotes. If "static" is provided, only show static emotes. - Otherwise, show all emotes. + Otherwise, or if “all” is provided, show all emotes. """ - - animated = animated.lower() - if animated == 'animated': - pred = lambda e: e.animated - elif animated == 'static': - pred = lambda e: not e.animated - else: - pred = lambda e: True - emotes = sorted( - filter(pred, context.guild.emojis), + filter(animated, context.guild.emojis), key=lambda e: e.name.lower()) processed = [] diff --git a/utils/converter.py b/utils/converter.py new file mode 100644 index 0000000..510bbff --- /dev/null +++ b/utils/converter.py @@ -0,0 +1,11 @@ +_emote_type_predicates = { + '': lambda _: True, # allow usage as a "consume rest" converter + 'all': lambda _: True, + 'static': lambda e: not e.animated, + 'animated': lambda e: e.animated} + +def emote_type_filter(argument): + try: + return _emote_type_predicates[argument.lower()] + except KeyError: + raise commands.BadArgument('Invalid emote type. Specify “static”, “animated”, “all”.') diff --git a/utils/misc.py b/utils/misc.py index e4b5e69..d0d60a7 100644 --- a/utils/misc.py +++ b/utils/misc.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 # encoding: utf-8 -import discord - """various utilities for use within the bot""" +import asyncio + +import discord + def format_user(bot, id, *, mention=False): """Format a user ID for human readable display.""" user = bot.get_user(id) @@ -37,3 +39,16 @@ def strip_angle_brackets(string): if string.startswith('<') and string.endswith('>'): return string[1:-1] return string + +async def gather_or_cancel(*awaitables, loop=None): + """run the awaitables in the sequence concurrently. If any of them raise an exception, + propagate the first exception raised and cancel all other awaitables. + """ + gather_task = asyncio.gather(*awaitables, loop=loop) + try: + return await gather_task + except asyncio.CancelledError: + raise + except: + gather_task.cancel() + raise