Skip to content

Commit b3a031f

Browse files
committed
Expose staker sampling as an iterator (ish)
1 parent 44cb58b commit b3a031f

File tree

6 files changed

+122
-113
lines changed

6 files changed

+122
-113
lines changed

nucypher/blockchain/eth/actors.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
PolicyManagerAgent,
4343
PreallocationEscrowAgent,
4444
StakingEscrowAgent,
45-
WorkLockAgent
45+
WorkLockAgent,
46+
StakersReservoir,
4647
)
4748
from nucypher.blockchain.eth.constants import NULL_ADDRESS
4849
from nucypher.blockchain.eth.decorators import (
@@ -1458,16 +1459,11 @@ def generate_policy_parameters(self,
14581459
payload = {**blockchain_payload, **policy_end_time}
14591460
return payload
14601461

1461-
def recruit(self, quantity: int, **options) -> List[str]:
1462+
def get_stakers_reservoir(self, **options) -> StakersReservoir:
14621463
"""
1463-
Uses sampling logic to gather stakers from the blockchain and
1464-
caches the resulting node ethereum addresses.
1465-
1466-
:param quantity: Number of ursulas to sample from the blockchain.
1467-
1464+
Get a sampler object containing the currently registered stakers.
14681465
"""
1469-
staker_addresses = self.staking_agent.sample(quantity=quantity, **options)
1470-
return staker_addresses
1466+
return self.staking_agent.get_stakers_reservoir(**options)
14711467

14721468
def create_policy(self, *args, **kwargs):
14731469
"""

nucypher/blockchain/eth/agents.py

Lines changed: 40 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -659,56 +659,23 @@ def swarm(self) -> Iterable[ChecksumAddress]:
659659
yield staker_address
660660

661661
@contract_api(CONTRACT_CALL)
662-
def sample(self,
663-
quantity: int,
664-
duration: int,
665-
pagination_size: Optional[int] = None
666-
) -> List[ChecksumAddress]:
667-
"""
668-
Select n random Stakers, according to their stake distribution.
669-
The returned addresses are shuffled.
670-
671-
See full diagram here: https://github.com/nucypher/kms-whitepaper/blob/master/pdf/miners-ruler.pdf
672-
673-
This method implements the Probability Proportional to Size (PPS) sampling algorithm.
674-
In few words, the algorithm places in a line all active stakes that have locked tokens for
675-
at least `duration` periods; a staker is selected if an input point is within its stake.
676-
For example:
677-
678-
```
679-
Stakes: |----- S0 ----|--------- S1 ---------|-- S2 --|---- S3 ---|-S4-|----- S5 -----|
680-
Points: ....R0.......................R1..................R2...............R3...........
681-
```
682-
683-
In this case, Stakers 0, 1, 3 and 5 will be selected.
684-
685-
Only stakers which made a commitment to the current period (in the previous period) are used.
686-
"""
662+
def get_stakers_reservoir(self,
663+
duration: int,
664+
without: Iterable[ChecksumAddress] = [],
665+
pagination_size: Optional[int] = None) -> 'StakersReservoir':
666+
n_tokens, stakers_map = self.get_all_active_stakers(periods=duration,
667+
pagination_size=pagination_size)
687668

688-
n_tokens, stakers_map = self.get_all_active_stakers(periods=duration, pagination_size=pagination_size)
669+
self.log.debug(f"Got {len(stakers_map)} stakers with {n_tokens} total tokens")
689670

690-
# TODO: can be implemented as an iterator if necessary, where the user can
691-
# sample addresses one by one without calling get_all_active_stakers() repeatedly.
671+
for address in without:
672+
del stakers_map[address]
692673

693-
if n_tokens == 0:
674+
# TODO: or is it enough to just make sure the number of remaining stakers is non-zero?
675+
if sum(stakers_map.values()) == 0:
694676
raise self.NotEnoughStakers('There are no locked tokens for duration {}.'.format(duration))
695677

696-
if quantity > len(stakers_map):
697-
raise self.NotEnoughStakers(f'Cannot sample {quantity} out of {len(stakers)} total stakers')
698-
699-
addresses = list(stakers_map.keys())
700-
tokens = list(stakers_map.values())
701-
sampler = WeightedSampler(addresses, tokens)
702-
703-
system_random = random.SystemRandom()
704-
sampled_addresses = sampler.sample_no_replacement(system_random, quantity)
705-
706-
# Randomize the output to avoid the largest stakers always being the first in the list
707-
system_random.shuffle(sampled_addresses) # inplace
708-
709-
self.log.debug(f"Sampled {len(addresses)} stakers: {list(sampled_addresses)}")
710-
711-
return sampled_addresses
678+
return StakersReservoir(stakers_map)
712679

713680
@contract_api(CONTRACT_CALL)
714681
def get_completed_work(self, bidder_address: ChecksumAddress) -> Work:
@@ -1584,7 +1551,10 @@ def sample_no_replacement(self, rng, quantity: int) -> list:
15841551
(does not mutate the object and only applies to the current invocation of the method).
15851552
"""
15861553

1587-
if quantity > len(self.totals):
1554+
if quantity == 0:
1555+
return []
1556+
1557+
if quantity > len(self):
15881558
raise ValueError("Cannot sample more than the total amount of elements without replacement")
15891559

15901560
totals = self.totals.copy()
@@ -1603,3 +1573,27 @@ def sample_no_replacement(self, rng, quantity: int) -> list:
16031573
totals[j] -= weight
16041574

16051575
return samples
1576+
1577+
def __len__(self):
1578+
return len(self.totals)
1579+
1580+
1581+
class StakersReservoir:
1582+
1583+
def __init__(self, stakers_map):
1584+
addresses = list(stakers_map.keys())
1585+
tokens = list(stakers_map.values())
1586+
self._sampler = WeightedSampler(addresses, tokens)
1587+
self._rng = random.SystemRandom()
1588+
1589+
def __len__(self):
1590+
return len(self._sampler)
1591+
1592+
def draw(self, quantity):
1593+
if quantity > len(self):
1594+
raise StakingEscrowAgent.NotEnoughStakers(f'Cannot sample {quantity} out of {len(self)} total stakers')
1595+
1596+
return self._sampler.sample_no_replacement(self._rng, quantity)
1597+
1598+
def draw_at_most(self, quantity):
1599+
return self.draw(min(quantity, len(self)))

nucypher/network/nodes.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from twisted.internet import defer, reactor, task
3636
from twisted.internet.threads import deferToThread
3737
from twisted.logger import Logger
38-
from typing import Set, Tuple, Union
38+
from typing import Set, Tuple, Union, Iterable
3939
from umbral.signing import Signature
4040

4141
import nucypher
@@ -604,9 +604,10 @@ def keep_learning_about_nodes(self):
604604
# TODO: Allow the user to set eagerness? 1712
605605
self.learn_from_teacher_node(eager=False)
606606

607-
def learn_about_specific_nodes(self, addresses: Set):
608-
self._node_ids_to_learn_about_immediately.update(addresses) # hmmmm
609-
self.learn_about_nodes_now()
607+
def learn_about_specific_nodes(self, addresses: Iterable):
608+
if len(addresses) > 0:
609+
self._node_ids_to_learn_about_immediately.update(addresses) # hmmmm
610+
self.learn_about_nodes_now()
610611

611612
# TODO: Dehydrate these next two methods. NRN
612613

nucypher/policy/policies.py

Lines changed: 59 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from bytestring_splitter import BytestringSplitter, VariableLengthBytestring
2424
from constant_sorrow.constants import NOT_SIGNED, UNKNOWN_KFRAG
2525
from twisted.logger import Logger
26-
from typing import Generator, List, Set
26+
from typing import Generator, List, Set, Optional
2727
from umbral.keys import UmbralPublicKey
2828
from umbral.kfrags import KFrag
2929

@@ -381,7 +381,7 @@ def consider_arrangement(self, network_middleware, ursula, arrangement) -> bool:
381381

382382
def make_arrangements(self,
383383
network_middleware: RestMiddleware,
384-
handpicked_ursulas: Set[Ursula] = None,
384+
handpicked_ursulas: Optional[Set[Ursula]] = None,
385385
*args, **kwargs,
386386
) -> None:
387387

@@ -408,11 +408,12 @@ def make_arrangement(self, ursula: Ursula, *args, **kwargs):
408408
raise NotImplementedError
409409

410410
@abstractmethod
411-
def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]:
411+
def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula]) -> Set[Ursula]:
412412
raise NotImplementedError
413413

414-
def sample(self, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]:
415-
selected_ursulas = set(handpicked_ursulas) if handpicked_ursulas else set()
414+
def sample(self, handpicked_ursulas: Optional[Set[Ursula]] = None) -> Set[Ursula]:
415+
handpicked_ursulas = handpicked_ursulas if handpicked_ursulas else set()
416+
selected_ursulas = set(handpicked_ursulas)
416417

417418
# Calculate the target sample quantity
418419
target_sample_quantity = self.n - len(selected_ursulas)
@@ -475,11 +476,11 @@ def make_arrangements(self, *args, **kwargs) -> None:
475476
"Pass them here as handpicked_ursulas.".format(self.n)
476477
raise self.MoreKFragsThanArrangements(error) # TODO: NotEnoughUrsulas where in the exception tree is this?
477478

478-
def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]:
479+
def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula]) -> Set[Ursula]:
479480
known_nodes = self.alice.known_nodes
480481
if handpicked_ursulas:
481482
# Prevent re-sampling of handpicked ursulas.
482-
known_nodes = set(known_nodes) - set(handpicked_ursulas)
483+
known_nodes = set(known_nodes) - handpicked_ursulas
483484
sampled_ursulas = set(random.sample(k=quantity, population=list(known_nodes)))
484485
return sampled_ursulas
485486

@@ -572,57 +573,68 @@ def generate_policy_parameters(n: int,
572573
params = dict(rate=rate, value=value)
573574
return params
574575

575-
def __find_ursulas(self,
576-
ether_addresses: List[str],
577-
target_quantity: int,
578-
timeout: int = 10) -> set: # TODO #843: Make timeout configurable
576+
def sample_essential(self,
577+
quantity: int,
578+
handpicked_ursulas: Set[Ursula],
579+
learner_timeout: int = 1,
580+
timeout: int = 10) -> Set[Ursula]:
579581

580-
start_time = maya.now() # marker for timeout calculation
582+
selected_addresses = set(handpicked_ursulas)
583+
quantity_remaining = quantity
581584

582-
found_ursulas, unknown_addresses = set(), deque()
583-
while len(found_ursulas) < target_quantity: # until there are enough Ursulas
585+
# Need to sample some stakers
584586

585-
delta = maya.now() - start_time # check for a timeout
586-
if delta.total_seconds() >= timeout:
587-
missing_nodes = ', '.join(a for a in unknown_addresses)
588-
raise RuntimeError("Timed out after {} seconds; Cannot find {}.".format(timeout, missing_nodes))
587+
reservoir = self.alice.get_stakers_reservoir(duration=self.duration_periods,
588+
without=handpicked_ursulas)
589+
if len(reservoir) < quantity_remaining:
590+
error = f"Cannot create policy with {quantity} arrangements"
591+
raise self.NotEnoughBlockchainUrsulas(error)
589592

590-
# Select an ether_address: Prefer the selection pool, then unknowns queue
591-
if ether_addresses:
592-
ether_address = ether_addresses.pop()
593-
else:
594-
ether_address = unknown_addresses.popleft()
593+
to_check = reservoir.draw(quantity_remaining)
595594

596-
try:
597-
# Check if this is a known node.
598-
selected_ursula = self.alice.known_nodes[ether_address]
595+
# Sample stakers in a loop and feed them to the learner to check
596+
# until we have enough in `selected_addresses`.
599597

600-
except KeyError:
601-
# Unknown Node
602-
self.alice.learn_about_specific_nodes({ether_address}) # enter address in learning loop
603-
unknown_addresses.append(ether_address)
604-
continue
598+
start_time = maya.now()
599+
new_to_check = to_check
600+
601+
while True:
602+
603+
# Check if the sampled addresses are already known.
604+
# If we're lucky, we won't have to wait for the learner iteration to finish.
605+
known = list(filter(lambda x: x in self.alice.known_nodes, to_check))
606+
to_check = list(filter(lambda x: x not in self.alice.known_nodes, to_check))
605607

608+
known = known[:min(len(known), quantity_remaining)] # we only need so many
609+
selected_addresses.update(known)
610+
quantity_remaining -= len(known)
611+
612+
if quantity_remaining == 0:
613+
break
606614
else:
607-
# Known Node
608-
found_ursulas.add(selected_ursula) # We already knew, or just learned about this ursula
615+
new_to_check = reservoir.draw_at_most(quantity_remaining)
616+
to_check.extend(new_to_check)
609617

610-
return found_ursulas
618+
# Feed newly sampled stakers to the learner
619+
self.alice.learn_about_specific_nodes(new_to_check)
611620

612-
def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]:
613-
# TODO: Prevent re-sampling of handpicked ursulas.
614-
selected_addresses = set()
615-
try:
616-
sampled_addresses = self.alice.recruit(quantity=quantity,
617-
duration=self.duration_periods)
618-
except StakingEscrowAgent.NotEnoughStakers as e:
619-
error = f"Cannot create policy with {quantity} arrangements: {e}"
620-
raise self.NotEnoughBlockchainUrsulas(error)
621+
# TODO: would be nice to wait for the learner to finish an iteration here,
622+
# because if it hasn't, we really have nothing to do.
623+
time.sleep(learner_timeout)
624+
625+
delta = maya.now() - start_time
626+
if delta.total_seconds() >= timeout:
627+
still_checking = ', '.join(to_check)
628+
raise RuntimeError(f"Timed out after {timeout} seconds; "
629+
f"need {quantity} more, still checking {still_checking}.")
630+
631+
found_ursulas = list(selected_addresses)
632+
633+
# Randomize the output to avoid the largest stakers always being the first in the list
634+
system_random = random.SystemRandom()
635+
system_random.shuffle(found_ursulas) # inplace
621636

622-
# Capture the selection and search the network for those Ursulas
623-
selected_addresses.update(sampled_addresses)
624-
found_ursulas = self.__find_ursulas(sampled_addresses, quantity)
625-
return found_ursulas
637+
return set(found_ursulas)
626638

627639
def publish_to_blockchain(self) -> dict:
628640

tests/acceptance/blockchain/agents/test_sampling_distribution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def test_sampling_distribution(testerchain, token, deploy_contract, token_econom
117117
sampled, failed = 0, 0
118118
while sampled < SAMPLES:
119119
try:
120-
addresses = set(staking_agent.sample(quantity=quantity, duration=1))
120+
reservoir = staking_agent.get_stakers_reservoir(duration=1)
121+
addresses = set(reservoir.draw(quantity))
121122
addresses.discard(NULL_ADDRESS)
122123
except staking_agent.NotEnoughStakers:
123124
failed += 1

tests/acceptance/blockchain/agents/test_staking_escrow_agent.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,20 +150,25 @@ def test_sample_stakers(agency):
150150
_token_agent, staking_agent, _policy_agent = agency
151151
stakers_population = staking_agent.get_staker_population()
152152

153+
reservoir = staking_agent.get_stakers_reservoir(duration=1)
153154
with pytest.raises(StakingEscrowAgent.NotEnoughStakers):
154-
staking_agent.sample(quantity=stakers_population + 1, duration=1) # One more than we have deployed
155+
reservoir.draw(stakers_population + 1) # One more than we have deployed
155156

156-
stakers = staking_agent.sample(quantity=3, duration=5)
157+
reservoir = staking_agent.get_stakers_reservoir(duration=5)
158+
stakers = reservoir.draw(3)
157159
assert len(stakers) == 3 # Three...
158160
assert len(set(stakers)) == 3 # ...unique addresses
159161

160162
# Same but with pagination
161-
stakers = staking_agent.sample(quantity=3, duration=5, pagination_size=1)
163+
reservoir = staking_agent.get_stakers_reservoir(duration=5, pagination_size=1)
164+
stakers = reservoir.draw(3)
162165
assert len(stakers) == 3
163166
assert len(set(stakers)) == 3
164167
light = staking_agent.blockchain.is_light
165168
staking_agent.blockchain.is_light = not light
166-
stakers = staking_agent.sample(quantity=3, duration=5)
169+
170+
reservoir = staking_agent.get_stakers_reservoir(duration=5)
171+
stakers = reservoir.draw(3)
167172
assert len(stakers) == 3
168173
assert len(set(stakers)) == 3
169174
staking_agent.blockchain.is_light = light
@@ -261,13 +266,13 @@ def test_lock_restaking(agency, testerchain, test_registry):
261266
staking_agent = ContractAgency.get_agent(StakingEscrowAgent, registry=test_registry)
262267
current_period = staking_agent.get_current_period()
263268
terminal_period = current_period + 2
264-
269+
265270
assert staking_agent.is_restaking(staker_account)
266271
assert not staking_agent.is_restaking_locked(staker_account)
267272
receipt = staking_agent.lock_restaking(staker_account, release_period=terminal_period)
268273
assert receipt['status'] == 1, "Transaction Rejected"
269274
assert staking_agent.is_restaking_locked(staker_account)
270-
275+
271276
testerchain.time_travel(periods=2) # Wait for re-staking lock to be released.
272277
assert not staking_agent.is_restaking_locked(staker_account)
273278

0 commit comments

Comments
 (0)