From 557e800de03a85ee22b4ffa6e036617150463ed9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E5=AE=87?= Date: Fri, 12 Aug 2022 18:29:29 +0800 Subject: [PATCH] WIP: Clean up argument parsing. --- src/main.rs | 378 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 277 insertions(+), 101 deletions(-) diff --git a/src/main.rs b/src/main.rs index 83f754c..883a805 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,15 @@ -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, Request, Response, Server, StatusCode}; use bitcoin::util::address::Address; use bitcoin::util::psbt::PartiallySignedTransaction; -use bitcoin::{TxOut, Script}; +use bitcoin::{Script, TxOut}; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Method, Request, Response, Server, StatusCode}; +use ln_types::P2PAddress; +use std::collections::{HashMap, VecDeque}; use std::convert::TryInto; -use std::sync::{Arc, Mutex}; -use std::collections::HashMap; +use std::ffi::OsString; use std::fmt; -use ln_types::P2PAddress; +use std::num::ParseIntError; +use std::sync::{Arc, Mutex}; #[macro_use] extern crate serde_derive; @@ -29,17 +31,15 @@ struct ScheduledChannel { } impl ScheduledChannel { - fn from_args(addr: std::ffi::OsString, amount: std::ffi::OsString) -> Self { - use configure_me::parse_arg::Arg; + fn from_args(addr: &str, amount: &str) -> Result { + let node = addr + .parse::() + .map_err(ArgError::InvalidNodeAddress)?; - let node = addr.parse().expect("invalid node address"); - let amount = amount.to_str().expect("invalid channel amount"); - let amount = bitcoin::Amount::from_str_in(&amount, bitcoin::Denomination::Satoshi).expect("invalid channel amount"); + let amount = bitcoin::Amount::from_str_in(amount, bitcoin::Denomination::Satoshi) + .map_err(ArgError::InvalidBitcoinAmount)?; - ScheduledChannel { - node, - amount, - } + Ok(Self { node, amount }) } } @@ -47,15 +47,24 @@ impl ScheduledChannel { struct ScheduledPayJoin { #[serde(with = "bitcoin::util::amount::serde::as_sat")] wallet_amount: bitcoin::Amount, - channels: Vec, + channels: VecDeque, fee_rate: u64, } impl ScheduledPayJoin { fn total_amount(&self) -> bitcoin::Amount { - let fees = calculate_fees(self.channels.len() as u64, self.fee_rate, self.wallet_amount != bitcoin::Amount::ZERO); - - self.channels.iter().map(|channel| channel.amount).fold(bitcoin::Amount::ZERO, std::ops::Add::add) + self.wallet_amount + fees + let fees = calculate_fees( + self.channels.len() as u64, + self.fee_rate, + self.wallet_amount != bitcoin::Amount::ZERO, + ); + + self.channels + .iter() + .map(|channel| channel.amount) + .fold(bitcoin::Amount::ZERO, std::ops::Add::add) + + self.wallet_amount + + fees } async fn test_connections(&self, client: &mut tonic_lnd::Client) { @@ -72,20 +81,27 @@ impl PayJoins { fn insert(&self, address: &Address, payjoin: ScheduledPayJoin) -> Result<(), ()> { use std::collections::hash_map::Entry; - match self.0.lock().expect("payjoins mutex poisoned").entry(address.script_pubkey()) { + match self + .0 + .lock() + .expect("payjoins mutex poisoned") + .entry(address.script_pubkey()) + { Entry::Vacant(place) => { place.insert(payjoin); Ok(()) - }, + } Entry::Occupied(_) => Err(()), } } fn find<'a>(&self, txouts: &'a mut [TxOut]) -> Option<(&'a mut TxOut, ScheduledPayJoin)> { let mut payjoins = self.0.lock().expect("payjoins mutex poisoned"); - txouts - .iter_mut() - .find_map(|txout| payjoins.remove(&txout.script_pubkey).map(|payjoin| (txout, payjoin))) + txouts.iter_mut().find_map(|txout| { + payjoins + .remove(&txout.script_pubkey) + .map(|payjoin| (txout, payjoin)) + }) } } @@ -100,40 +116,76 @@ impl Handler { let mut iter = match version.find('-') { Some(pos) => &version[..pos], None => &version, - }.split('.'); + } + .split('.'); - let major = iter.next().expect("split returns non-empty iterator").parse::(); + let major = iter + .next() + .expect("split returns non-empty iterator") + .parse::(); let minor = iter.next().unwrap_or("0").parse::(); let patch = iter.next().unwrap_or("0").parse::(); match (major, minor, patch) { (Ok(major), Ok(minor), Ok(patch)) => Ok(((major, minor, patch), version)), - (Err(error), _, _) => Err(CheckError::VersionNumber { version, error, }), - (_, Err(error), _) => Err(CheckError::VersionNumber { version, error, }), - (_, _, Err(error)) => Err(CheckError::VersionNumber { version, error, }), + (Err(error), _, _) => Err(CheckError::VersionNumber { version, error }), + (_, Err(error), _) => Err(CheckError::VersionNumber { version, error }), + (_, _, Err(error)) => Err(CheckError::VersionNumber { version, error }), } } async fn new(mut client: tonic_lnd::Client) -> Result { - let version = client.get_info(tonic_lnd::rpc::GetInfoRequest {}).await?.into_inner().version; + let version = client + .get_info(tonic_lnd::rpc::GetInfoRequest {}) + .await? + .into_inner() + .version; let (parsed_version, version) = Self::parse_lnd_version(version)?; if parsed_version < (0, 14, 0) { return Err(CheckError::LNDTooOld(version)); } else if parsed_version < (0, 14, 2) { - eprintln!("WARNING: LND older than 0.14.2. Using with an empty LND wallet is impossible."); + eprintln!( + "WARNING: LND older than 0.14.2. Using with an empty LND wallet is impossible." + ); } Ok(Handler { client, payjoins: Default::default(), }) } +} +/// CLI argument errors. +#[derive(Debug)] +enum ArgError { + /// Argument not UTF-8 + NotUTF8(OsString), + /// Parse feerate error + FeeRateError(ParseIntError), + /// Parse node address error + InvalidNodeAddress(ln_types::p2p_address::ParseError), + /// Parse bitcoin amount error + InvalidBitcoinAmount(bitcoin::util::amount::ParseAmountError), + /// Wallet amount error + InvalidWalletAmount(bitcoin::util::amount::ParseAmountError), +} + +impl fmt::Display for ArgError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // TODO: Do this properly. + write!(f, "invalid arguments: {:?}", self) + } } +impl std::error::Error for ArgError {} + #[derive(Debug)] enum CheckError { RequestFailed(tonic_lnd::Error), - VersionNumber { version: String, error: std::num::ParseIntError, }, + VersionNumber { + version: String, + error: std::num::ParseIntError, + }, LNDTooOld(String), } @@ -141,8 +193,14 @@ impl fmt::Display for CheckError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { CheckError::RequestFailed(_) => write!(f, "failed to get LND version"), - CheckError::VersionNumber { version, error: _, } => write!(f, "Unparsable LND version '{}'", version), - CheckError::LNDTooOld(version) => write!(f, "LND version {} is too old - it would cause GUARANTEED LOSS of sats!", version), + CheckError::VersionNumber { version, error: _ } => { + write!(f, "Unparsable LND version '{}'", version) + } + CheckError::LNDTooOld(version) => write!( + f, + "LND version {} is too old - it would cause GUARANTEED LOSS of sats!", + version + ), } } } @@ -151,7 +209,7 @@ impl std::error::Error for CheckError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { CheckError::RequestFailed(error) => Some(error), - CheckError::VersionNumber { version: _, error, } => Some(error), + CheckError::VersionNumber { version: _, error } => Some(error), CheckError::LNDTooOld(_) => None, } } @@ -163,7 +221,6 @@ impl From for CheckError { } } - async fn ensure_connected(client: &mut tonic_lnd::Client, node: &P2PAddress) { let pubkey = node.node_id.to_string(); let peer_addr = tonic_lnd::rpc::LightningAddress { @@ -177,14 +234,22 @@ async fn ensure_connected(client: &mut tonic_lnd::Client, node: &P2PAddress) { timeout: 60, }; - client.connect_peer(connect_req).await.map(drop).unwrap_or_else(|error| { - if !error.message().starts_with("already connected to peer") { - panic!("failed to connect to peer {}: {:?}", node, error); - } - }); + client + .connect_peer(connect_req) + .await + .map(drop) + .unwrap_or_else(|error| { + if !error.message().starts_with("already connected to peer") { + panic!("failed to connect to peer {}: {:?}", node, error); + } + }); } -fn calculate_fees(channel_count: u64, fee_rate: u64, has_additional_output: bool) -> bitcoin::Amount { +fn calculate_fees( + channel_count: u64, + fee_rate: u64, + has_additional_output: bool, +) -> bitcoin::Amount { let additional_vsize = if has_additional_output { channel_count * (8 + 1 + 1 + 32) } else { @@ -196,7 +261,10 @@ fn calculate_fees(channel_count: u64, fee_rate: u64, has_additional_output: bool async fn get_new_bech32_address(client: &mut tonic_lnd::Client) -> Address { client - .new_address(tonic_lnd::rpc::NewAddressRequest { r#type: 0, account: String::new(), }) + .new_address(tonic_lnd::rpc::NewAddressRequest { + r#type: 0, + account: String::new(), + }) .await .expect("failed to get chain address") .into_inner() @@ -205,41 +273,93 @@ async fn get_new_bech32_address(client: &mut tonic_lnd::Client) -> Address { .expect("lnd returned invalid address") } -#[tokio::main] -async fn main() -> Result<(), Box> { - let (config, mut args) = Config::including_optional_config_files(std::iter::empty::<&str>()).unwrap_or_exit(); +fn process_args>( + args: A, +) -> Result, ArgError> { + // ensure all args are utf8 + let args = args + .map(|arg| arg.into_string()) + .collect::, _>>() + .map_err(ArgError::NotUTF8)?; + + // first argument is fee rate + let fee_rate = match args.get(0) { + Some(fee_rate_str) => fee_rate_str + .parse::() + .map_err(ArgError::FeeRateError)?, + None => return Ok(None), + }; - let client = tonic_lnd::connect(config.lnd_address, &config.lnd_cert_path, &config.lnd_macaroon_path) - .await - .expect("failed to connect"); + // the remaining argument is the wallet amount (if any) + let mut wallet_amount = None; + + // parse scheduled channel arguments: pairs of (addr, amount) + let channels = match args.get(1..) { + Some(args) => args + .chunks(2) + .filter_map(|chunk| match chunk.len() { + // parse scheduled channel arg: a pair of (addr, amount) + 2 => Some(ScheduledChannel::from_args(&chunk[0], &chunk[1])), // p2p_address, btc_amount + // parse wallet amount (optional) + 1 => { + wallet_amount = Some( + bitcoin::Amount::from_str_in( + chunk[0].as_str(), + bitcoin::Denomination::Satoshi, + ) + .map_err(ArgError::InvalidWalletAmount), + ); + None + } + unexpected_len => panic!("chunk len should only be 1 or 2, got {}", unexpected_len), + }) + .collect::, _>>()?, + None => return Ok(None), + }; - let mut handler = Handler::new(client).await?; + // process wallet amount + let wallet_amount = wallet_amount.transpose()?.unwrap_or(bitcoin::Amount::ZERO); - if let Some(fee_rate) = args.next() { - let fee_rate = fee_rate.into_string().expect("fee rate is not UTF-8").parse::()?; - let address = get_new_bech32_address(&mut handler.client).await; + Ok(Some(ScheduledPayJoin { + wallet_amount, + channels, + fee_rate, + })) +} - let mut args = args.fuse(); - let mut scheduled_channels = Vec::with_capacity(args.size_hint().0 / 2); - let mut wallet_amount = bitcoin::Amount::ZERO; - while let Some(arg) = args.next() { - match args.next() { - Some(channel_amount) => scheduled_channels.push(ScheduledChannel::from_args(arg, channel_amount)), - None => wallet_amount = bitcoin::Amount::from_str_in(arg.to_str().expect("wallet amount not UTF-8"), bitcoin::Denomination::Satoshi)?, - } - } +#[tokio::main] +async fn main() -> Result<(), Box> { + let (config, args) = + Config::including_optional_config_files(std::iter::empty::<&str>()).unwrap_or_exit(); + + let scheduled_pj = process_args(args).expect("failed to process args"); - let scheduled_payjoin = ScheduledPayJoin { - wallet_amount, - channels: scheduled_channels, - fee_rate, - }; + let client = tonic_lnd::connect( + config.lnd_address, + &config.lnd_cert_path, + &config.lnd_macaroon_path, + ) + .await + .expect("failed to connect"); - scheduled_payjoin.test_connections(&mut handler.client).await; + let mut handler = Handler::new(client).await?; - println!("bitcoin:{}?amount={}&pj=https://example.com/pj", address, scheduled_payjoin.total_amount().to_string_in(bitcoin::Denomination::Bitcoin)); + if let Some(payjoin) = scheduled_pj { + payjoin.test_connections(&mut handler.client).await; + let address = get_new_bech32_address(&mut handler.client).await; - handler.payjoins.insert(&address, scheduled_payjoin).expect("New Handler is supposed to be empty"); + println!( + "bitcoin:{}?amount={}&pj=https://example.com/pj", + address, + payjoin + .total_amount() + .to_string_in(bitcoin::Denomination::Bitcoin) + ); + + handler + .payjoins + .insert(&address, payjoin) + .expect("New Handler is supposed to be empty"); } let addr = ([127, 0, 0, 1], config.bind_port).into(); @@ -248,8 +368,9 @@ async fn main() -> Result<(), Box> { let handler = handler.clone(); async move { - - Ok::<_, hyper::Error>(service_fn(move |request| handle_web_req(handler.clone(), request))) + Ok::<_, hyper::Error>(service_fn(move |request| { + handle_web_req(handler.clone(), request) + })) } }); @@ -262,21 +383,27 @@ async fn main() -> Result<(), Box> { Ok(()) } -async fn handle_web_req(mut handler: Handler, req: Request) -> Result, hyper::Error> { +async fn handle_web_req( + mut handler: Handler, + req: Request, +) -> Result, hyper::Error> { use bitcoin::consensus::{Decodable, Encodable}; use std::path::Path; match (req.method(), req.uri().path()) { (&Method::GET, "/pj") => { - let index = std::fs::read(Path::new(STATIC_DIR).join("index.html")).expect("can't open index"); + let index = + std::fs::read(Path::new(STATIC_DIR).join("index.html")).expect("can't open index"); Ok(Response::new(Body::from(index))) - }, + } (&Method::GET, path) if path.starts_with("/pj/static/") => { let directory_traversal_vulnerable_path = &path[("/pj/static/".len())..]; - let file = std::fs::read(Path::new(STATIC_DIR).join(directory_traversal_vulnerable_path)).expect("can't open static file"); + let file = + std::fs::read(Path::new(STATIC_DIR).join(directory_traversal_vulnerable_path)) + .expect("can't open static file"); Ok(Response::new(Body::from(file))) - }, + } (&Method::POST, "/pj") => { dbg!(req.uri().query()); @@ -305,13 +432,30 @@ async fn handle_web_req(mut handler: Handler, req: Request) -> Result()).collect::>(); + let (our_output, scheduled_payjoin) = handler + .payjoins + .find(&mut psbt.global.unsigned_tx.output) + .expect("the transaction doesn't contain our output"); + let total_channel_amount: bitcoin::Amount = scheduled_payjoin + .channels + .iter() + .map(|channel| channel.amount) + .fold(bitcoin::Amount::ZERO, std::ops::Add::add); + let fees = calculate_fees( + scheduled_payjoin.channels.len() as u64, + scheduled_payjoin.fee_rate, + scheduled_payjoin.wallet_amount != bitcoin::Amount::ZERO, + ); + + assert_eq!( + our_output.value, + (total_channel_amount + scheduled_payjoin.wallet_amount + fees).as_sat() + ); + + let chids = (0..scheduled_payjoin.channels.len()) + .into_iter() + .map(|_| rand::random::<[u8; 32]>()) + .collect::>(); // no collect() because of async let mut txouts = Vec::with_capacity(scheduled_payjoin.channels.len()); @@ -332,7 +476,11 @@ async fn handle_web_req(mut handler: Handler, req: Request) -> Result) -> Result { let mut bytes = &*ready.psbt; - let tx = PartiallySignedTransaction::consensus_decode(&mut bytes).unwrap(); + let tx = PartiallySignedTransaction::consensus_decode(&mut bytes) + .unwrap(); eprintln!("PSBT received from LND: {:#?}", tx); assert_eq!(tx.global.unsigned_tx.output.len(), 1); txouts.extend(tx.global.unsigned_tx.output); break; - }, + } // panic? x => panic!("Unexpected update {:?}", x), } @@ -371,7 +528,10 @@ async fn handle_web_req(mut handler: Handler, req: Request) -> Result) -> Result) -> Result { let bytes = hyper::body::to_bytes(req.into_body()).await?; - let request = serde_json::from_slice::(&bytes).expect("invalid request"); + let request = + serde_json::from_slice::(&bytes).expect("invalid request"); request.test_connections(&mut handler.client).await; let address = get_new_bech32_address(&mut handler.client).await; let total_amount = request.total_amount(); - handler.payjoins.insert(&address, request).expect("address reuse"); - - let uri = format!("bitcoin:{}?amount={}&pj=https://example.com/pj", address, total_amount.to_string_in(bitcoin::Denomination::Bitcoin)); + handler + .payjoins + .insert(&address, request) + .expect("address reuse"); + + let uri = format!( + "bitcoin:{}?amount={}&pj=https://example.com/pj", + address, + total_amount.to_string_in(bitcoin::Denomination::Bitcoin) + ); let mut response = Response::new(Body::from(uri)); - response.headers_mut().insert(hyper::header::CONTENT_TYPE, "text/plain".parse().unwrap()); + response + .headers_mut() + .insert(hyper::header::CONTENT_TYPE, "text/plain".parse().unwrap()); Ok(response) - }, + } // Return the 404 Not Found for other routes. _ => {