diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..a764604 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,79 @@ +## This is copied from https://github.com/rust-bitcoin/rust-bitcoin/blob/master/rustfmt.toml + +hard_tabs = false +tab_spaces = 4 +newline_style = "Auto" +indent_style = "Block" + +max_width = 100 # This is number of characters. +# `use_small_heuristics` is ignored if the granular width config values are explicitly set. +use_small_heuristics = "Max" # "Max" == All granular width settings same as `max_width`. +# # Granular width configuration settings. These are percentages of `max_width`. +# fn_call_width = 60 +# attr_fn_like_width = 70 +# struct_lit_width = 18 +# struct_variant_width = 35 +# array_width = 60 +# chain_width = 60 +# single_line_if_else_max_width = 50 + +wrap_comments = false +format_code_in_doc_comments = false +comment_width = 100 # Default 80 +normalize_comments = false +normalize_doc_attributes = false +format_strings = false +format_macro_matchers = false +format_macro_bodies = true +hex_literal_case = "Preserve" +empty_item_single_line = true +struct_lit_single_line = true +fn_single_line = true # Default false +where_single_line = false +imports_indent = "Block" +imports_layout = "Mixed" +imports_granularity = "Module" # Default "Preserve" +group_imports = "StdExternalCrate" # Default "Preserve" +reorder_imports = true +reorder_modules = true +reorder_impl_items = false +type_punctuation_density = "Wide" +space_before_colon = false +space_after_colon = true +spaces_around_ranges = false +binop_separator = "Front" +remove_nested_parens = true +combine_control_expr = true +overflow_delimited_expr = false +struct_field_align_threshold = 0 +enum_discrim_align_threshold = 0 +match_arm_blocks = false # Default true +match_arm_leading_pipes = "Never" +force_multiline_blocks = false +fn_args_layout = "Tall" +brace_style = "SameLineWhere" +control_brace_style = "AlwaysSameLine" +trailing_semicolon = true +trailing_comma = "Vertical" +match_block_trailing_comma = false +blank_lines_upper_bound = 1 +blank_lines_lower_bound = 0 +edition = "2018" +version = "One" +inline_attribute_width = 0 +format_generated_files = true +merge_derives = true +use_try_shorthand = false +use_field_init_shorthand = false +force_explicit_abi = true +condense_wildcard_suffixes = false +color = "Auto" +required_version = "1.5.1" +unstable_features = false +disable_all_formatting = false +skip_children = false +hide_parse_errors = false +error_on_line_overflow = false +error_on_unformatted = false +emit_mode = "Files" +make_backup = false diff --git a/src/main.rs b/src/main.rs index 83f754c..328904e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,13 @@ -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, Request, Response, Server, StatusCode}; use bitcoin::util::address::Address; use bitcoin::util::psbt::PartiallySignedTransaction; -use bitcoin::{TxOut, Script}; -use std::convert::TryInto; -use std::sync::{Arc, Mutex}; +use bitcoin::{Script, TxOut}; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Method, Request, Response, Server, StatusCode}; +use ln_types::P2PAddress; use std::collections::HashMap; +use std::convert::TryInto; use std::fmt; -use ln_types::P2PAddress; +use std::sync::{Arc, Mutex}; #[macro_use] extern crate serde_derive; @@ -34,12 +34,10 @@ impl ScheduledChannel { let node = addr.parse().expect("invalid node address"); let amount = amount.to_str().expect("invalid channel amount"); - let amount = bitcoin::Amount::from_str_in(&amount, bitcoin::Denomination::Satoshi).expect("invalid channel amount"); + let amount = bitcoin::Amount::from_str_in(&amount, bitcoin::Denomination::Satoshi) + .expect("invalid channel amount"); - ScheduledChannel { - node, - amount, - } + ScheduledChannel { node, amount } } } @@ -53,9 +51,18 @@ struct ScheduledPayJoin { impl ScheduledPayJoin { fn total_amount(&self) -> bitcoin::Amount { - let fees = calculate_fees(self.channels.len() as u64, self.fee_rate, self.wallet_amount != bitcoin::Amount::ZERO); - - self.channels.iter().map(|channel| channel.amount).fold(bitcoin::Amount::ZERO, std::ops::Add::add) + self.wallet_amount + fees + let fees = calculate_fees( + self.channels.len() as u64, + self.fee_rate, + self.wallet_amount != bitcoin::Amount::ZERO, + ); + + self.channels + .iter() + .map(|channel| channel.amount) + .fold(bitcoin::Amount::ZERO, std::ops::Add::add) + + self.wallet_amount + + fees } async fn test_connections(&self, client: &mut tonic_lnd::Client) { @@ -76,7 +83,7 @@ impl PayJoins { Entry::Vacant(place) => { place.insert(payjoin); Ok(()) - }, + } Entry::Occupied(_) => Err(()), } } @@ -100,7 +107,8 @@ impl Handler { let mut iter = match version.find('-') { Some(pos) => &version[..pos], None => &version, - }.split('.'); + } + .split('.'); let major = iter.next().expect("split returns non-empty iterator").parse::(); let minor = iter.next().unwrap_or("0").parse::(); @@ -108,32 +116,31 @@ impl Handler { match (major, minor, patch) { (Ok(major), Ok(minor), Ok(patch)) => Ok(((major, minor, patch), version)), - (Err(error), _, _) => Err(CheckError::VersionNumber { version, error, }), - (_, Err(error), _) => Err(CheckError::VersionNumber { version, error, }), - (_, _, Err(error)) => Err(CheckError::VersionNumber { version, error, }), + (Err(error), _, _) => Err(CheckError::VersionNumber { version, error }), + (_, Err(error), _) => Err(CheckError::VersionNumber { version, error }), + (_, _, Err(error)) => Err(CheckError::VersionNumber { version, error }), } } async fn new(mut client: tonic_lnd::Client) -> Result { - let version = client.get_info(tonic_lnd::rpc::GetInfoRequest {}).await?.into_inner().version; + let version = + client.get_info(tonic_lnd::rpc::GetInfoRequest {}).await?.into_inner().version; let (parsed_version, version) = Self::parse_lnd_version(version)?; if parsed_version < (0, 14, 0) { return Err(CheckError::LNDTooOld(version)); } else if parsed_version < (0, 14, 2) { - eprintln!("WARNING: LND older than 0.14.2. Using with an empty LND wallet is impossible."); + eprintln!( + "WARNING: LND older than 0.14.2. Using with an empty LND wallet is impossible." + ); } - Ok(Handler { - client, - payjoins: Default::default(), - }) + Ok(Handler { client, payjoins: Default::default() }) } - } #[derive(Debug)] enum CheckError { RequestFailed(tonic_lnd::Error), - VersionNumber { version: String, error: std::num::ParseIntError, }, + VersionNumber { version: String, error: std::num::ParseIntError }, LNDTooOld(String), } @@ -141,8 +148,14 @@ impl fmt::Display for CheckError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { CheckError::RequestFailed(_) => write!(f, "failed to get LND version"), - CheckError::VersionNumber { version, error: _, } => write!(f, "Unparsable LND version '{}'", version), - CheckError::LNDTooOld(version) => write!(f, "LND version {} is too old - it would cause GUARANTEED LOSS of sats!", version), + CheckError::VersionNumber { version, error: _ } => { + write!(f, "Unparsable LND version '{}'", version) + } + CheckError::LNDTooOld(version) => write!( + f, + "LND version {} is too old - it would cause GUARANTEED LOSS of sats!", + version + ), } } } @@ -151,7 +164,7 @@ impl std::error::Error for CheckError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { CheckError::RequestFailed(error) => Some(error), - CheckError::VersionNumber { version: _, error, } => Some(error), + CheckError::VersionNumber { version: _, error } => Some(error), CheckError::LNDTooOld(_) => None, } } @@ -163,19 +176,13 @@ impl From for CheckError { } } - async fn ensure_connected(client: &mut tonic_lnd::Client, node: &P2PAddress) { let pubkey = node.node_id.to_string(); - let peer_addr = tonic_lnd::rpc::LightningAddress { - pubkey: pubkey, - host: node.as_host_port().to_string(), - }; + let peer_addr = + tonic_lnd::rpc::LightningAddress { pubkey: pubkey, host: node.as_host_port().to_string() }; - let connect_req = tonic_lnd::rpc::ConnectPeerRequest { - addr: Some(peer_addr), - perm: true, - timeout: 60, - }; + let connect_req = + tonic_lnd::rpc::ConnectPeerRequest { addr: Some(peer_addr), perm: true, timeout: 60 }; client.connect_peer(connect_req).await.map(drop).unwrap_or_else(|error| { if !error.message().starts_with("already connected to peer") { @@ -184,7 +191,11 @@ async fn ensure_connected(client: &mut tonic_lnd::Client, node: &P2PAddress) { }); } -fn calculate_fees(channel_count: u64, fee_rate: u64, has_additional_output: bool) -> bitcoin::Amount { +fn calculate_fees( + channel_count: u64, + fee_rate: u64, + has_additional_output: bool, +) -> bitcoin::Amount { let additional_vsize = if has_additional_output { channel_count * (8 + 1 + 1 + 32) } else { @@ -196,7 +207,7 @@ fn calculate_fees(channel_count: u64, fee_rate: u64, has_additional_output: bool async fn get_new_bech32_address(client: &mut tonic_lnd::Client) -> Address { client - .new_address(tonic_lnd::rpc::NewAddressRequest { r#type: 0, account: String::new(), }) + .new_address(tonic_lnd::rpc::NewAddressRequest { r#type: 0, account: String::new() }) .await .expect("failed to get chain address") .into_inner() @@ -207,11 +218,13 @@ async fn get_new_bech32_address(client: &mut tonic_lnd::Client) -> Address { #[tokio::main] async fn main() -> Result<(), Box> { - let (config, mut args) = Config::including_optional_config_files(std::iter::empty::<&str>()).unwrap_or_exit(); + let (config, mut args) = + Config::including_optional_config_files(std::iter::empty::<&str>()).unwrap_or_exit(); - let client = tonic_lnd::connect(config.lnd_address, &config.lnd_cert_path, &config.lnd_macaroon_path) - .await - .expect("failed to connect"); + let client = + tonic_lnd::connect(config.lnd_address, &config.lnd_cert_path, &config.lnd_macaroon_path) + .await + .expect("failed to connect"); let mut handler = Handler::new(client).await?; @@ -224,22 +237,33 @@ async fn main() -> Result<(), Box> { let mut wallet_amount = bitcoin::Amount::ZERO; while let Some(arg) = args.next() { match args.next() { - Some(channel_amount) => scheduled_channels.push(ScheduledChannel::from_args(arg, channel_amount)), - None => wallet_amount = bitcoin::Amount::from_str_in(arg.to_str().expect("wallet amount not UTF-8"), bitcoin::Denomination::Satoshi)?, + Some(channel_amount) => { + scheduled_channels.push(ScheduledChannel::from_args(arg, channel_amount)) + } + None => { + wallet_amount = bitcoin::Amount::from_str_in( + arg.to_str().expect("wallet amount not UTF-8"), + bitcoin::Denomination::Satoshi, + )? + } } } - let scheduled_payjoin = ScheduledPayJoin { - wallet_amount, - channels: scheduled_channels, - fee_rate, - }; + let scheduled_payjoin = + ScheduledPayJoin { wallet_amount, channels: scheduled_channels, fee_rate }; scheduled_payjoin.test_connections(&mut handler.client).await; - println!("bitcoin:{}?amount={}&pj=https://example.com/pj", address, scheduled_payjoin.total_amount().to_string_in(bitcoin::Denomination::Bitcoin)); + println!( + "bitcoin:{}?amount={}&pj=https://example.com/pj", + address, + scheduled_payjoin.total_amount().to_string_in(bitcoin::Denomination::Bitcoin) + ); - handler.payjoins.insert(&address, scheduled_payjoin).expect("New Handler is supposed to be empty"); + handler + .payjoins + .insert(&address, scheduled_payjoin) + .expect("New Handler is supposed to be empty"); } let addr = ([127, 0, 0, 1], config.bind_port).into(); @@ -248,8 +272,9 @@ async fn main() -> Result<(), Box> { let handler = handler.clone(); async move { - - Ok::<_, hyper::Error>(service_fn(move |request| handle_web_req(handler.clone(), request))) + Ok::<_, hyper::Error>(service_fn(move |request| { + handle_web_req(handler.clone(), request) + })) } }); @@ -262,21 +287,27 @@ async fn main() -> Result<(), Box> { Ok(()) } -async fn handle_web_req(mut handler: Handler, req: Request) -> Result, hyper::Error> { +async fn handle_web_req( + mut handler: Handler, + req: Request, +) -> Result, hyper::Error> { use bitcoin::consensus::{Decodable, Encodable}; use std::path::Path; match (req.method(), req.uri().path()) { (&Method::GET, "/pj") => { - let index = std::fs::read(Path::new(STATIC_DIR).join("index.html")).expect("can't open index"); + let index = + std::fs::read(Path::new(STATIC_DIR).join("index.html")).expect("can't open index"); Ok(Response::new(Body::from(index))) - }, + } (&Method::GET, path) if path.starts_with("/pj/static/") => { let directory_traversal_vulnerable_path = &path[("/pj/static/".len())..]; - let file = std::fs::read(Path::new(STATIC_DIR).join(directory_traversal_vulnerable_path)).expect("can't open static file"); + let file = + std::fs::read(Path::new(STATIC_DIR).join(directory_traversal_vulnerable_path)) + .expect("can't open static file"); Ok(Response::new(Body::from(file))) - }, + } (&Method::POST, "/pj") => { dbg!(req.uri().query()); @@ -305,13 +336,30 @@ async fn handle_web_req(mut handler: Handler, req: Request) -> Result()).collect::>(); + let (our_output, scheduled_payjoin) = handler + .payjoins + .find(&mut psbt.global.unsigned_tx.output) + .expect("the transaction doesn't contain our output"); + let total_channel_amount: bitcoin::Amount = scheduled_payjoin + .channels + .iter() + .map(|channel| channel.amount) + .fold(bitcoin::Amount::ZERO, std::ops::Add::add); + let fees = calculate_fees( + scheduled_payjoin.channels.len() as u64, + scheduled_payjoin.fee_rate, + scheduled_payjoin.wallet_amount != bitcoin::Amount::ZERO, + ); + + assert_eq!( + our_output.value, + (total_channel_amount + scheduled_payjoin.wallet_amount + fees).as_sat() + ); + + let chids = (0..scheduled_payjoin.channels.len()) + .into_iter() + .map(|_| rand::random::<[u8; 32]>()) + .collect::>(); // no collect() because of async let mut txouts = Vec::with_capacity(scheduled_payjoin.channels.len()); @@ -324,15 +372,17 @@ async fn handle_web_req(mut handler: Handler, req: Request) -> Result) -> Result { let mut bytes = &*ready.psbt; - let tx = PartiallySignedTransaction::consensus_decode(&mut bytes).unwrap(); + let tx = PartiallySignedTransaction::consensus_decode(&mut bytes) + .unwrap(); eprintln!("PSBT received from LND: {:#?}", tx); assert_eq!(tx.global.unsigned_tx.output.len(), 1); txouts.extend(tx.global.unsigned_tx.output); break; - }, + } // panic? x => panic!("Unexpected update {:?}", x), } @@ -393,35 +450,47 @@ async fn handle_web_req(mut handler: Handler, req: Request) -> Result { let bytes = hyper::body::to_bytes(req.into_body()).await?; - let request = serde_json::from_slice::(&bytes).expect("invalid request"); + let request = + serde_json::from_slice::(&bytes).expect("invalid request"); request.test_connections(&mut handler.client).await; let address = get_new_bech32_address(&mut handler.client).await; let total_amount = request.total_amount(); handler.payjoins.insert(&address, request).expect("address reuse"); - let uri = format!("bitcoin:{}?amount={}&pj=https://example.com/pj", address, total_amount.to_string_in(bitcoin::Denomination::Bitcoin)); + let uri = format!( + "bitcoin:{}?amount={}&pj=https://example.com/pj", + address, + total_amount.to_string_in(bitcoin::Denomination::Bitcoin) + ); let mut response = Response::new(Body::from(uri)); - response.headers_mut().insert(hyper::header::CONTENT_TYPE, "text/plain".parse().unwrap()); + response + .headers_mut() + .insert(hyper::header::CONTENT_TYPE, "text/plain".parse().unwrap()); Ok(response) - }, + } // Return the 404 Not Found for other routes. _ => {