Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions sopel/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,30 +602,26 @@ def rate_limit_info(
if trigger.admin or rule.is_unblockable():
return False, None

nick = trigger.nick
is_channel = trigger.sender and not trigger.sender.is_nick()
channel = trigger.sender if is_channel else None

at_time = trigger.time

user_metrics = rule.get_user_metrics(trigger.nick)
channel_metrics = rule.get_channel_metrics(channel)
global_metrics = rule.get_global_metrics()

if user_metrics.is_limited(at_time - rule.user_rate_limit):
if rule.is_user_rate_limited(nick, at_time):
template = rule.user_rate_template
rate_limit_type = "user"
rate_limit = rule.user_rate_limit
metrics = user_metrics
elif is_channel and channel_metrics.is_limited(at_time - rule.channel_rate_limit):
metrics = rule.get_user_metrics(nick)
elif channel and rule.is_channel_rate_limited(channel, at_time):
template = rule.channel_rate_template
rate_limit_type = "channel"
rate_limit = rule.channel_rate_limit
metrics = channel_metrics
elif global_metrics.is_limited(at_time - rule.global_rate_limit):
metrics = rule.get_channel_metrics(channel)
elif rule.is_global_rate_limited(at_time):
template = rule.global_rate_template
rate_limit_type = "global"
rate_limit = rule.global_rate_limit
metrics = global_metrics
metrics = rule.get_global_metrics()
else:
return False, None

Expand Down
57 changes: 33 additions & 24 deletions sopel/plugins/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,40 +765,49 @@ def global_rate_limit(self) -> datetime.timedelta:
def is_user_rate_limited(
self,
nick: Identifier,
at_time: Optional[datetime.datetime] = None,
at_time: datetime.datetime,
) -> bool:
"""Tell when the rule reached the ``nick``'s rate limit.

:param nick: the nick associated with this check
:param at_time: optional aware datetime for the rate limit check;
if not given, ``utcnow`` will be used
:param at_time: aware datetime for the rate limit check
:return: ``True`` when the rule reached the limit, ``False`` otherwise.

.. versionchanged:: 8.0.1

Parameter ``at_time`` is now required.

"""

@abc.abstractmethod
def is_channel_rate_limited(
self,
channel: Identifier,
at_time: Optional[datetime.datetime] = None,
at_time: datetime.datetime,
) -> bool:
"""Tell when the rule reached the ``channel``'s rate limit.

:param channel: the channel associated with this check
:param at_time: optional aware datetime for the rate limit check;
if not given, ``utcnow`` will be used
:param at_time: aware datetime for the rate limit check
:return: ``True`` when the rule reached the limit, ``False`` otherwise.

.. versionchanged:: 8.0.1

Parameter ``at_time`` is now required.

"""

@abc.abstractmethod
def is_global_rate_limited(
self,
at_time: Optional[datetime.datetime] = None,
) -> bool:
def is_global_rate_limited(self, at_time: datetime.datetime) -> bool:
"""Tell when the rule reached the global rate limit.

:param at_time: optional aware datetime for the rate limit check;
if not given, ``utcnow`` will be used
:param at_time: aware datetime for the rate limit check
:return: ``True`` when the rule reached the limit, ``False`` otherwise.

.. versionchanged:: 8.0.1

Parameter ``at_time`` is now required.

"""

@property
Expand Down Expand Up @@ -1209,29 +1218,29 @@ def global_rate_limit(self) -> datetime.timedelta:
def is_user_rate_limited(
self,
nick: Identifier,
at_time: Optional[datetime.datetime] = None,
at_time: datetime.datetime,
) -> bool:
if at_time is None:
at_time = datetime.datetime.now(datetime.timezone.utc)
if self._user_rate_limit <= 0:
return False

metrics = self.get_user_metrics(nick)
return metrics.is_limited(at_time - self.user_rate_limit)

def is_channel_rate_limited(
self,
channel: Identifier,
at_time: Optional[datetime.datetime] = None,
at_time: datetime.datetime,
) -> bool:
if at_time is None:
at_time = datetime.datetime.now(datetime.timezone.utc)
if self._channel_rate_limit <= 0:
return False

metrics = self.get_channel_metrics(channel)
return metrics.is_limited(at_time - self.channel_rate_limit)

def is_global_rate_limited(
self,
at_time: Optional[datetime.datetime] = None,
) -> bool:
if at_time is None:
at_time = datetime.datetime.now(datetime.timezone.utc)
def is_global_rate_limited(self, at_time: datetime.datetime) -> bool:
if self._global_rate_limit <= 0:
return False

metrics = self.get_global_metrics()
return metrics.is_limited(at_time - self.global_rate_limit)

Expand Down
42 changes: 24 additions & 18 deletions test/plugins/test_plugins_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,14 +1566,16 @@ def handler(bot, trigger):
global_rate_limit=20,
channel_rate_limit=20,
)
assert rule.is_user_rate_limited(mocktrigger.nick) is False
assert rule.is_channel_rate_limited(mocktrigger.sender) is False
assert rule.is_global_rate_limited() is False
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False
assert rule.is_global_rate_limited(at_time) is False

rule.execute(mockbot, mocktrigger)
assert rule.is_user_rate_limited(mocktrigger.nick) is True
assert rule.is_channel_rate_limited(mocktrigger.sender) is True
assert rule.is_global_rate_limited() is True
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is True
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is True
assert rule.is_global_rate_limited(at_time) is True


def test_rule_rate_limit_no_limit(mockbot, triggerfactory):
Expand All @@ -1592,14 +1594,16 @@ def handler(bot, trigger):
global_rate_limit=0,
channel_rate_limit=0,
)
assert rule.is_user_rate_limited(mocktrigger.nick) is False
assert rule.is_channel_rate_limited(mocktrigger.sender) is False
assert rule.is_global_rate_limited() is False
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False
assert rule.is_global_rate_limited(at_time) is False

rule.execute(mockbot, mocktrigger)
assert rule.is_user_rate_limited(mocktrigger.nick) is False
assert rule.is_channel_rate_limited(mocktrigger.sender) is False
assert rule.is_global_rate_limited() is False
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False
assert rule.is_global_rate_limited(at_time) is False


def test_rule_rate_limit_ignore_rate_limit(mockbot, triggerfactory):
Expand All @@ -1619,14 +1623,16 @@ def handler(bot, trigger):
channel_rate_limit=20,
threaded=False, # make sure there is no race-condition here
)
assert rule.is_user_rate_limited(mocktrigger.nick) is False
assert rule.is_channel_rate_limited(mocktrigger.sender) is False
assert rule.is_global_rate_limited() is False
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False
assert rule.is_global_rate_limited(at_time) is False

rule.execute(mockbot, mocktrigger)
assert rule.is_user_rate_limited(mocktrigger.nick) is False
assert rule.is_channel_rate_limited(mocktrigger.sender) is False
assert rule.is_global_rate_limited() is False
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False
assert rule.is_global_rate_limited(at_time) is False


def test_rule_rate_limit_messages(mockbot, triggerfactory):
Expand Down
104 changes: 80 additions & 24 deletions test/test_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

if typing.TYPE_CHECKING:
from sopel.config import Config
from sopel.tests.factories import BotFactory, IRCFactory, UserFactory
from sopel.tests.factories import (
BotFactory, ConfigFactory, IRCFactory, TriggerFactory, UserFactory,
)
from sopel.tests.mocks import MockIRCServer


Expand Down Expand Up @@ -81,17 +83,17 @@ def ignored():


@pytest.fixture
def tmpconfig(configfactory):
def tmpconfig(configfactory: ConfigFactory) -> Config:
return configfactory('test.cfg', TMP_CONFIG)


@pytest.fixture
def mockbot(tmpconfig, botfactory):
def mockbot(tmpconfig: Config, botfactory: BotFactory) -> bot.Sopel:
return botfactory(tmpconfig)


@pytest.fixture
def mockplugin(tmpdir):
def mockplugin(tmpdir) -> plugins.handlers.PyFilePlugin:
root = tmpdir.mkdir('loader_mods')
mod_file = root.join('mockplugin.py')
mod_file.write(MOCK_MODULE_CONTENT)
Expand Down Expand Up @@ -676,7 +678,7 @@ def url_callback_http(bot, trigger, match):
# call_rule

@pytest.fixture
def match_hello_rule(mockbot, triggerfactory):
def match_hello_rule(mockbot: bot.Sopel, triggerfactory: TriggerFactory):
"""Helper for generating matches to each `Rule` in the following tests"""
def _factory(rule_hello):
# trigger
Expand All @@ -694,7 +696,25 @@ def _factory(rule_hello):
return _factory


def test_call_rule(mockbot, match_hello_rule):
@pytest.fixture
def multimatch_hello_rule(mockbot: bot.Sopel, triggerfactory: TriggerFactory):
def _factory(rule_hello):
# trigger
line = ':[email protected] PRIVMSG #channel :hello hello hello'

trigger = triggerfactory(mockbot, line)
pretrigger = trigger._pretrigger

for match in rule_hello.match(mockbot, pretrigger):
wrapper = bot.SopelWrapper(mockbot, trigger)
yield match, trigger, wrapper
return _factory


def test_call_rule(
mockbot: bot.Sopel,
match_hello_rule: typing.Callable,
) -> None:
# setup
items = []

Expand All @@ -721,9 +741,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is not rate limited
assert not rule_hello.is_user_rate_limited(Identifier('Test'))
assert not rule_hello.is_channel_rate_limited('#channel')
assert not rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert not rule_hello.is_channel_rate_limited('#channel', at_time)
assert not rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand All @@ -738,6 +759,36 @@ def testrule(bot, trigger):
assert items == [1, 1]


def test_call_rule_multiple_matches(
mockbot: bot.Sopel,
multimatch_hello_rule: typing.Callable,
) -> None:
# setup
items = []

def testrule(bot, trigger):
bot.say('hi')
items.append(1)
return "Return Value"

find_hello = rules.FindRule(
[re.compile(r'(hi|hello|hey|sup)')],
plugin='testplugin',
label='testrule',
handler=testrule)

for match, rule_trigger, wrapper in multimatch_hello_rule(find_hello):
mockbot.call_rule(find_hello, wrapper, rule_trigger)

# assert the rule has been executed three times now
assert mockbot.backend.message_sent == rawlist(
'PRIVMSG #channel :hi',
'PRIVMSG #channel :hi',
'PRIVMSG #channel :hi',
)
assert items == [1, 1, 1]


def test_call_rule_rate_limited_user(mockbot, match_hello_rule):
items = []

Expand Down Expand Up @@ -767,9 +818,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is now rate limited
assert rule_hello.is_user_rate_limited(Identifier('Test'))
assert not rule_hello.is_channel_rate_limited('#channel')
assert not rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert not rule_hello.is_channel_rate_limited('#channel', at_time)
assert not rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand Down Expand Up @@ -852,9 +904,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is now rate limited
assert not rule_hello.is_user_rate_limited(Identifier('Test'))
assert rule_hello.is_channel_rate_limited('#channel')
assert not rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert rule_hello.is_channel_rate_limited('#channel', at_time)
assert not rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand Down Expand Up @@ -897,9 +950,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is now rate limited
assert not rule_hello.is_user_rate_limited(Identifier('Test'))
assert rule_hello.is_channel_rate_limited('#channel')
assert not rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert rule_hello.is_channel_rate_limited('#channel', at_time)
assert not rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand Down Expand Up @@ -942,9 +996,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is now rate limited
assert not rule_hello.is_user_rate_limited(Identifier('Test'))
assert not rule_hello.is_channel_rate_limited('#channel')
assert rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert not rule_hello.is_channel_rate_limited('#channel', at_time)
assert rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand Down Expand Up @@ -987,9 +1042,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is now rate limited
assert not rule_hello.is_user_rate_limited(Identifier('Test'))
assert not rule_hello.is_channel_rate_limited('#channel')
assert rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert not rule_hello.is_channel_rate_limited('#channel', at_time)
assert rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand Down