Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,102 changes: 840 additions & 262 deletions Cargo.lock

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ pin-project-lite = "0.2"
subxt = "0.29"
scale-value = "0.10.0"

# substrate
frame-election-provider-support = { git = "https://github.com/paritytech/substrate" }
pallet-election-provider-multi-phase = { git = "https://github.com/paritytech/substrate" }
sp-npos-elections = { git = "https://github.com/paritytech/substrate" }
frame-support = { git = "https://github.com/paritytech/substrate" }
sp-runtime = { git = "https://github.com/paritytech/substrate" }
# polkadot-sdk
frame-election-provider-support = { git = "https://github.com/paritytech/polkadot-sdk" }
pallet-election-provider-multi-phase = { git = "https://github.com/paritytech/polkadot-sdk" }
sp-npos-elections = { git = "https://github.com/paritytech/polkadot-sdk/" }
frame-support = { git = "https://github.com/paritytech/polkadot-sdk" }
sp-runtime = { git = "https://github.com/paritytech/polkadot-sdk" }

# prometheus
prometheus = "0.13"
Expand All @@ -37,7 +37,7 @@ once_cell = "1.18"
[dev-dependencies]
anyhow = "1"
assert_cmd = "2.0"
sp-storage = { git = "https://github.com/paritytech/substrate" }
sp-storage = { git = "https://github.com/paritytech/polkadot-sdk" }
regex = "1"

[features]
Expand Down
5 changes: 2 additions & 3 deletions src/commands/monitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ where

(solution, score)
},
(Err(e), _) => return Err(e),
(Err(e), _) => return Err(Error::Other(e.to_string())),
};

let best_head = get_latest_head(&api, config.listen).await?;
Expand Down Expand Up @@ -421,8 +421,7 @@ where
(Err(e), _) => {
log::warn!(
target: LOG_TARGET,
"submit_and_watch_solution failed: {:?}; skipping block: {}",
e,
"submit_and_watch_solution failed: {e}; skipping block: {}",
at.number
);
},
Expand Down
258 changes: 221 additions & 37 deletions src/epm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,34 @@ use crate::{
helpers::{storage_at, RuntimeDispatchInfo},
opt::{BalanceIterations, Balancing, Solver},
prelude::*,
static_types,
prometheus,
static_types::{self},
};

use std::{
collections::{BTreeMap, BTreeSet},
marker::PhantomData,
};

use codec::{Decode, Encode};
use frame_election_provider_support::{NposSolution, PhragMMS, SequentialPhragmen};
use frame_support::weights::Weight;
use pallet_election_provider_multi_phase::{RawSolution, ReadySolution, SolutionOrSnapshotSize};
use frame_election_provider_support::{Get, NposSolution, PhragMMS, SequentialPhragmen};
use frame_support::{weights::Weight, BoundedVec};
use pallet_election_provider_multi_phase::{
unsigned::TrimmingStatus, RawSolution, ReadySolution, SolutionOf, SolutionOrSnapshotSize,
};
use scale_info::{PortableRegistry, TypeInfo};
use scale_value::scale::{decode_as_type, TypeId};
use sp_core::Bytes;
use sp_npos_elections::ElectionScore;
use sp_npos_elections::{ElectionScore, VoteWeight};
use subxt::{dynamic::Value, rpc::rpc_params, tx::DynamicPayload};

const EPM_PALLET_NAME: &str = "ElectionProviderMultiPhase";

type MinerVoterOf =
frame_election_provider_support::Voter<AccountId, crate::static_types::MaxVotesPerVoter>;

type RoundSnapshot = pallet_election_provider_multi_phase::RoundSnapshot<AccountId, MinerVoterOf>;
type Voters =
Vec<(AccountId, VoteWeight, BoundedVec<AccountId, crate::static_types::MaxVotesPerVoter>)>;

#[derive(Copy, Clone, Debug)]
struct EpmConstant {
Expand All @@ -62,6 +72,112 @@ impl std::fmt::Display for EpmConstant {
}
}

#[derive(Debug)]
pub struct State {
voters: Voters,
voters_by_stake: BTreeMap<VoteWeight, usize>,
}

impl State {
fn len(&self) -> usize {
self.voters_by_stake.len()
}

fn to_voters(&self) -> Voters {
self.voters.clone()
}
}

/// Represent voters that may be trimmed
///
/// The trimming works by removing the voter with the least amount of stake.
///
/// It's using an internal `BTreeMap` to determine which voter to remove next
/// and the voters Vec can't be sorted because the EPM pallet will index into it
/// when checking the solution.
#[derive(Debug)]
pub struct TrimmedVoters<T> {
state: State,
_marker: PhantomData<T>,
}

impl<T> TrimmedVoters<T>
where
T: MinerConfig<AccountId = AccountId, MaxVotesPerVoter = static_types::MaxVotesPerVoter>
+ Send
+ Sync
+ 'static,
T::Solution: Send,
{
/// Create a new `TrimmedVotes`.
pub async fn new(mut voters: Voters, desired_targets: u32) -> Result<Self, Error> {
let mut voters_by_stake = BTreeMap::new();

for (idx, (_voter, stake, _supports)) in voters.iter().enumerate() {
voters_by_stake.insert(*stake, idx);
}

while let Some((_, idx)) = voters_by_stake.pop_first() {
let rm = voters[idx].0.clone();

let mut targets = BTreeSet::new();
let active_voters =
voters_by_stake.len().try_into().expect("Voters must be < u32::MAX");

// Remove votes for an account.
for (_voter, _stake, supports) in &mut voters {
supports.retain(|a| a != &rm);
targets.extend(supports);
}

let desired_targets = desired_targets;
let targets = targets.len() as u32;

let est_weight: Weight = tokio::task::spawn_blocking(move || {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decided to add this in the constructor/new because it should only be needed to do once

T::solution_weight(active_voters, targets, active_voters, desired_targets)
})
.await?;
let max_weight: Weight = T::MaxWeight::get();

if est_weight.all_lt(max_weight) {
return Ok(Self { state: State { voters, voters_by_stake }, _marker: PhantomData })
}
}

return Err(Error::Feasibility("Couldn't pre-trim votes to prevent trimming".to_string()))
}

/// Clone the state and trim it, so it get can be reverted.
pub fn trim(&mut self, n: usize) -> Result<State, Error> {
let mut voters = self.state.voters.clone();
let mut voters_by_stake = self.state.voters_by_stake.clone();

for _ in 0..n {
let Some((_, idx)) = voters_by_stake.pop_first() else {
return Err(Error::Feasibility(
"Couldn't pre-trim votes to prevent trimming".to_string(),
))
};
let rm = voters[idx].0.clone();

// Remove votes for an account.
for (_voter, _stake, supports) in &mut voters {
supports.retain(|a| a != &rm);
}
}

Ok(State { voters, voters_by_stake })
}

pub fn to_voters(&self) -> Voters {
self.state.voters.clone()
}

pub fn len(&self) -> usize {
self.state.len()
}
}

/// Read the constants from the metadata and updates the static types.
pub(crate) async fn update_metadata_constants(api: &SubxtClient) -> Result<(), Error> {
const SIGNED_MAX_WEIGHT: EpmConstant = EpmConstant::new("SignedMaxWeight");
Expand Down Expand Up @@ -186,8 +302,44 @@ pub async fn snapshot_at(
}
}

/// Helper to fetch snapshot data via RPC
pub async fn mine_solution<T>(
solver: Solver,
targets: Vec<AccountId>,
voters: Voters,
desired_targets: u32,
) -> Result<(SolutionOf<T>, ElectionScore, SolutionOrSnapshotSize, TrimmingStatus), Error>
where
T: MinerConfig<AccountId = AccountId, MaxVotesPerVoter = static_types::MaxVotesPerVoter>
+ Send
+ Sync
+ 'static,
T::Solution: Send,
{
match tokio::task::spawn_blocking(move || match solver {
Solver::SeqPhragmen { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<
SequentialPhragmen<AccountId, Accuracy, Balancing>,
>(voters, targets, desired_targets)
},
Solver::PhragMMS { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<PhragMMS<AccountId, Accuracy, Balancing>>(
voters,
targets,
desired_targets,
)
},
})
.await
{
Ok(Ok(s)) => Ok(s),
Err(e) => Err(e.into()),
Ok(Err(e)) => Err(Error::Other(format!("{:?}", e))),
}
}

/// Helper to fetch snapshot data via RPC
/// and compute an NPos solution via [`pallet_election_provider_multi_phase`].
pub async fn fetch_snapshot_and_mine_solution<T>(
api: &SubxtClient,
Expand Down Expand Up @@ -219,47 +371,69 @@ where
.await?
.map(|score| score.0);

let voters = snapshot.voters.clone();
let targets = snapshot.targets.clone();
let mut voters = TrimmedVoters::<T>::new(snapshot.voters.clone(), desired_targets).await?;

log::trace!(
target: LOG_TARGET,
"mine solution: desired_targets={}, voters={}, targets={}",
let (solution, score, solution_or_snapshot_size, trim_status) = mine_solution::<T>(
solver.clone(),
snapshot.targets.clone(),
voters.to_voters(),
desired_targets,
voters.len(),
targets.len()
);
)
.await?;

let blocking_task = tokio::task::spawn_blocking(move || match solver {
Solver::SeqPhragmen { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<
SequentialPhragmen<AccountId, Accuracy, Balancing>,
>(voters, targets, desired_targets)
},
Solver::PhragMMS { iterations } => {
BalanceIterations::set(iterations);
Miner::<T>::mine_solution_with_snapshot::<PhragMMS<AccountId, Accuracy, Balancing>>(
voters,
targets,
desired_targets,
)
},
})
.await;
if !trim_status.is_trimmed() {
return Ok(MinedSolution {
round,
desired_targets,
snapshot,
minimum_untrusted_score,
solution,
score,
solution_or_snapshot_size,
})
}

prometheus::on_trim_attempt();

let mut l = 1;
let mut h = voters.len();
let mut best_solution = None;

while l <= h {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

binary search to find the best "pre-trim"

let mid = ((h - l) / 2) + l;

let next_state = voters.trim(mid)?;

match blocking_task {
Ok(Ok((solution, score, solution_or_snapshot_size))) => Ok(MinedSolution {
let (solution, score, solution_or_snapshot_size, trim_status) = mine_solution::<T>(
solver.clone(),
snapshot.targets.clone(),
next_state.to_voters(),
desired_targets,
)
.await?;

if !trim_status.is_trimmed() {
best_solution = Some((solution, score, solution_or_snapshot_size));
h = mid - 1;
} else {
l = mid + 1;
}
}

if let Some((solution, score, solution_or_snapshot_size)) = best_solution {
prometheus::on_trim_success();

Ok(MinedSolution {
round,
desired_targets,
snapshot,
minimum_untrusted_score,
solution,
score,
solution_or_snapshot_size,
}),
Ok(Err(err)) => Err(Error::Other(format!("{:?}", err))),
Err(err) => Err(err.into()),
})
} else {
Err(Error::Feasibility("Couldn't pre-trim votes to prevent trimming".to_string()))
}
}

Expand Down Expand Up @@ -315,6 +489,16 @@ where
}
}

impl<T: MinerConfig> std::fmt::Debug for MinedSolution<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MinedSolution")
.field("round", &self.round)
.field("desired_targets", &self.desired_targets)
.field("score", &self.score)
.finish()
}
}

fn make_type<T: scale_info::TypeInfo + 'static>() -> (TypeId, PortableRegistry) {
let m = scale_info::MetaType::new::<T>();
let mut types = scale_info::Registry::new();
Expand Down
Loading