Skip to content

Commit

Permalink
Store local and remote inputs/output in separate vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
optout21 committed Apr 16, 2024
1 parent 50dce76 commit ef18e92
Showing 1 changed file with 44 additions and 40 deletions.
84 changes: 44 additions & 40 deletions lightning/src/ln/interactivetxs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ struct NegotiationContext {
holder_is_initiator: bool,
received_tx_add_input_count: u16,
received_tx_add_output_count: u16,
/// The inputs to be contributed by the holder.
inputs: HashMap<SerialId, InteractiveTxInput>,
/// The inputs contributed by the holder
local_inputs: HashMap<SerialId, InteractiveTxInput>,
/// The inputs contributed by the counterparty
remote_inputs: HashMap<SerialId, InteractiveTxInput>,
/// The output intended to be the new funding output.
/// When an output added to the same pubkey, it will be treated as the shared output.
/// The script pubkey is used to discriminate which output is the funding output.
Expand All @@ -108,8 +110,10 @@ struct NegotiationContext {
/// Note: this output is also included in `outputs`.
actual_new_funding_output: Option<SharedOutput>,
prevtx_outpoints: HashSet<OutPoint>,
/// The outputs to be contributed by the holder (excluding the funding output)
outputs: HashMap<SerialId, InteractiveTxOutput>,
/// The outputs contributed by the holder
local_outputs: HashMap<SerialId, InteractiveTxOutput>,
/// The outputs contributed by the counterparty
remote_outputs: HashMap<SerialId, InteractiveTxOutput>,
/// The locktime of the funding transaction.
tx_locktime: AbsoluteLockTime,
/// The fee rate used for the transaction
Expand All @@ -129,12 +133,14 @@ impl NegotiationContext {
holder_is_initiator,
received_tx_add_input_count: 0,
received_tx_add_output_count: 0,
inputs: new_hash_map(),
local_inputs: new_hash_map(),
remote_inputs: new_hash_map(),
intended_new_funding_output,
intended_local_contribution_satoshis,
actual_new_funding_output: None,
prevtx_outpoints: new_hash_set(),
outputs: new_hash_map(),
local_outputs: new_hash_map(),
remote_outputs: new_hash_map(),
tx_locktime,
feerate_sat_per_kw,
}
Expand Down Expand Up @@ -191,24 +197,16 @@ impl NegotiationContext {
self.holder_is_initiator == serial_id.is_for_non_initiator()
}

fn total_input_and_output_count(&self) -> usize {
self.inputs.len().saturating_add(self.outputs.len())
fn total_input_count(&self) -> usize {
self.local_inputs.len().saturating_add(self.remote_inputs.len())
}

fn counterparty_inputs_contributed(&self) -> impl Iterator<Item = &InteractiveTxInput> + Clone {
self.inputs
.iter()
.filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
.map(|(_, input_with_prevout)| input_with_prevout)
fn total_output_count(&self) -> usize {
self.local_outputs.len().saturating_add(self.remote_outputs.len())
}

fn counterparty_outputs_contributed(
&self,
) -> impl Iterator<Item = &InteractiveTxOutput> + Clone {
self.outputs
.iter()
.filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id))
.map(|(_, output)| output)
fn total_input_and_output_count(&self) -> usize {
self.total_input_count().saturating_add(self.total_output_count())
}

fn received_tx_add_input(&mut self, msg: &msgs::TxAddInput) -> Result<(), AbortReason> {
Expand Down Expand Up @@ -265,7 +263,7 @@ impl NegotiationContext {
}

let prev_outpoint = OutPoint { txid, vout: msg.prevtx_out };
match self.inputs.entry(msg.serial_id) {
match self.remote_inputs.entry(msg.serial_id) {
hash_map::Entry::Occupied(_) => {
// The receiving node:
// - MUST fail the negotiation if:
Expand Down Expand Up @@ -303,7 +301,7 @@ impl NegotiationContext {
return Err(AbortReason::IncorrectSerialIdParity);
}

self.inputs
self.remote_inputs
.remove(&msg.serial_id)
// The receiving node:
// - MUST fail the negotiation if:
Expand Down Expand Up @@ -339,7 +337,7 @@ impl NegotiationContext {
// Check that adding this output would not cause the total output value to exceed the total
// bitcoin supply.
let mut outputs_value: u64 = 0;
for output in self.outputs.iter() {
for output in self.local_outputs.iter().chain(self.remote_outputs.iter()) {
outputs_value = outputs_value.saturating_add(output.1.value());
}
if outputs_value.saturating_add(msg.sats) > TOTAL_BITCOIN_SUPPLY_SATOSHIS {
Expand Down Expand Up @@ -377,7 +375,7 @@ impl NegotiationContext {
} else {
InteractiveTxOutput::Remote(RemoteOutput { serial_id: msg.serial_id, txout })
};
match self.outputs.entry(msg.serial_id) {
match self.remote_outputs.entry(msg.serial_id) {
hash_map::Entry::Occupied(_) => {
// The receiving node:
// - MUST fail the negotiation if:
Expand All @@ -395,7 +393,7 @@ impl NegotiationContext {
if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) {
return Err(AbortReason::IncorrectSerialIdParity);
}
if let Some(_) = self.outputs.remove(&msg.serial_id) {
if let Some(_) = self.remote_outputs.remove(&msg.serial_id) {
Ok(())
} else {
// The receiving node:
Expand Down Expand Up @@ -430,7 +428,7 @@ impl NegotiationContext {
.ok_or(AbortReason::PrevTxOutInvalid)?
.value,
});
self.inputs.insert(msg.serial_id, input);
self.local_inputs.insert(msg.serial_id, input);
Ok(())
}

Expand All @@ -443,17 +441,17 @@ impl NegotiationContext {
} else {
InteractiveTxOutput::Local(LocalOutput { serial_id: msg.serial_id, txout })
};
self.outputs.insert(msg.serial_id, output);
self.local_outputs.insert(msg.serial_id, output);
Ok(())
}

fn sent_tx_remove_input(&mut self, msg: &msgs::TxRemoveInput) -> Result<(), AbortReason> {
self.inputs.remove(&msg.serial_id);
self.local_inputs.remove(&msg.serial_id);
Ok(())
}

fn sent_tx_remove_output(&mut self, msg: &msgs::TxRemoveOutput) -> Result<(), AbortReason> {
self.outputs.remove(&msg.serial_id);
self.local_outputs.remove(&msg.serial_id);
Ok(())
}

Expand All @@ -464,11 +462,19 @@ impl NegotiationContext {
// - the peer's total input satoshis with its part of any shared input is less than their outputs
// and proportion of any shared output
let mut counterparty_value_in: u64 = 0;
for (_, input) in &self.inputs {
// Consider remote and local also, due to possible shared inputs
for (_, input) in &self.remote_inputs {
counterparty_value_in = counterparty_value_in.saturating_add(input.remote_value());
}
for (_, input) in &self.local_inputs {
counterparty_value_in = counterparty_value_in.saturating_add(input.remote_value());
}
let mut counterparty_value_out: u64 = 0;
for (_, output) in &self.outputs {
// Consider both local and remote, due to possible shared inputs
for (_, output) in &self.remote_outputs {
counterparty_value_out = counterparty_value_out.saturating_add(output.remote_value());
}
for (_, output) in &self.local_outputs {
counterparty_value_out = counterparty_value_out.saturating_add(output.remote_value());
}
if counterparty_value_in < counterparty_value_out {
Expand All @@ -477,8 +483,8 @@ impl NegotiationContext {

// - there are more than 252 inputs
// - there are more than 252 outputs
if self.inputs.len() > MAX_INPUTS_OUTPUTS_COUNT
|| self.outputs.len() > MAX_INPUTS_OUTPUTS_COUNT
if self.total_input_count() > MAX_INPUTS_OUTPUTS_COUNT
|| self.total_output_count() > MAX_INPUTS_OUTPUTS_COUNT
{
return Err(AbortReason::ExceededNumberOfInputsOrOutputs);
}
Expand All @@ -487,15 +493,13 @@ impl NegotiationContext {
const INPUT_WEIGHT: u64 = BASE_INPUT_WEIGHT + EMPTY_SCRIPT_SIG_WEIGHT;

// - the peer's paid feerate does not meet or exceed the agreed feerate (based on the minimum fee).
let mut counterparty_weight_contributed: u64 = self
.counterparty_outputs_contributed()
.map(|output| {
let mut counterparty_weight_contributed: u64 = self.remote_outputs.iter()
.map(|(_, output)| {
(8 /* value */ + output.script_pubkey().consensus_encode(&mut sink()).unwrap() as u64)
* WITNESS_SCALE_FACTOR as u64
})
.sum();
counterparty_weight_contributed +=
self.counterparty_inputs_contributed().count() as u64 * INPUT_WEIGHT;
counterparty_weight_contributed += self.remote_inputs.len() as u64 * INPUT_WEIGHT;
let counterparty_fees_contributed =
counterparty_value_in.saturating_sub(counterparty_value_out);
let mut required_counterparty_contribution_fee =
Expand All @@ -516,8 +520,8 @@ impl NegotiationContext {
}

// Inputs and outputs must be sorted by serial_id
let mut inputs = self.inputs.into_iter().collect::<Vec<_>>();
let mut outputs = self.outputs.into_iter().collect::<Vec<_>>();
let mut inputs = self.local_inputs.into_iter().chain(self.remote_inputs.into_iter()).collect::<Vec<_>>();
let mut outputs = self.local_outputs.into_iter().chain(self.remote_outputs.into_iter()).collect::<Vec<_>>();
inputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);
outputs.sort_unstable_by_key(|(serial_id, _)| *serial_id);

Expand Down

0 comments on commit ef18e92

Please sign in to comment.