Skip to content

Commit

Permalink
Merge #579
Browse files Browse the repository at this point in the history
579: Remove IpRepr::Unspecified r=Dirbaio a=Dirbaio

Implement simplifications unlocked by adding the `Context` parameter to socket methods. 

- tcp: immediately choose src addr on connect. Now that we have access to `Context` from `.connect()`, we don't have to delay setting it anymore.
- Remove IpRepr::Unspecified and lowering. Choosing the right source IP address is now responsibility of the individual sockets.



Co-authored-by: Dario Nieuwenhuis <dirbaio@dirbaio.net>
  • Loading branch information
bors[bot] and Dirbaio committed Mar 20, 2022
2 parents d05ebdd + 5989896 commit feedb3f
Show file tree
Hide file tree
Showing 13 changed files with 338 additions and 624 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Expand Up @@ -22,6 +22,7 @@ log = { version = "0.4.4", default-features = false, optional = true }
libc = { version = "0.2.18", optional = true }
bitflags = { version = "1.0", default-features = false }
defmt = { version = "0.3", optional = true }
cfg-if = "1.0.0"

[dev-dependencies]
env_logger = "0.9"
Expand Down
2 changes: 1 addition & 1 deletion benches/bench.rs
Expand Up @@ -85,7 +85,7 @@ mod wire {
let repr = Ipv4Repr {
src_addr: Ipv4Address([192, 168, 1, 1]),
dst_addr: Ipv4Address([192, 168, 1, 2]),
protocol: IpProtocol::Tcp,
next_header: IpProtocol::Tcp,
payload_len: 100,
hop_limit: 64,
};
Expand Down
94 changes: 60 additions & 34 deletions src/iface/interface.rs
Expand Up @@ -923,7 +923,8 @@ where
(IpRepr::Ipv6(ipv6_repr), IcmpRepr::Ipv6(icmpv6_repr)) => {
respond!(inner, IpPacket::Icmpv6((ipv6_repr, icmpv6_repr)))
}
_ => Err(Error::Unaddressable),
#[allow(unreachable_patterns)]
_ => unreachable!(),
}),
#[cfg(feature = "socket-udp")]
Socket::Udp(socket) => socket.dispatch(inner, |inner, response| {
Expand Down Expand Up @@ -1055,6 +1056,18 @@ impl<'a> InterfaceInner<'a> {
&mut self.rand
}

#[allow(unused)] // unused depending on which sockets are enabled
pub(crate) fn get_source_address(&mut self, dst_addr: IpAddress) -> Option<IpAddress> {
let v = dst_addr.version().unwrap();
for cidr in self.ip_addrs.iter() {
let addr = cidr.address();
if addr.version() == Some(v) {
return Some(addr);
}
}
None
}

#[cfg(test)]
pub(crate) fn mock() -> Self {
Self {
Expand All @@ -1080,7 +1093,15 @@ impl<'a> InterfaceInner<'a> {
},
now: Instant::from_millis_const(0),

ip_addrs: ManagedSlice::Owned(vec![]),
ip_addrs: ManagedSlice::Owned(vec![
#[cfg(feature = "proto-ipv4")]
IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address::new(192, 168, 1, 1), 24)),
#[cfg(feature = "proto-ipv6")]
IpCidr::Ipv6(Ipv6Cidr::new(
Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
64,
)),
]),
rand: Rand::new(1234),
routes: Routes::new(&mut [][..]),

Expand Down Expand Up @@ -1572,7 +1593,7 @@ impl<'a> InterfaceInner<'a> {

#[cfg(feature = "socket-dhcpv4")]
{
if ipv4_repr.protocol == IpProtocol::Udp && self.hardware_addr.is_some() {
if ipv4_repr.next_header == IpProtocol::Udp && self.hardware_addr.is_some() {
// First check for source and dest ports, then do `UdpRepr::parse` if they match.
// This way we avoid validating the UDP checksum twice for all non-DHCP UDP packets (one here, one in `process_udp`)
let udp_packet = UdpPacket::new_checked(ip_payload)?;
Expand Down Expand Up @@ -1617,7 +1638,7 @@ impl<'a> InterfaceInner<'a> {
}
}

match ipv4_repr.protocol {
match ipv4_repr.next_header {
IpProtocol::Icmp => self.process_icmpv4(sockets, ip_repr, ip_payload),

#[cfg(feature = "proto-igmp")]
Expand Down Expand Up @@ -1791,7 +1812,8 @@ impl<'a> InterfaceInner<'a> {
};
Ok(self.icmpv6_reply(ipv6_repr, icmp_reply_repr))
}
_ => Err(Error::Unrecognized),
#[allow(unreachable_patterns)]
_ => unreachable!(),
},

// Ignore any echo replies.
Expand All @@ -1801,7 +1823,8 @@ impl<'a> InterfaceInner<'a> {
#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))]
Icmpv6Repr::Ndisc(repr) if ip_repr.hop_limit() == 0xff => match ip_repr {
IpRepr::Ipv6(ipv6_repr) => self.process_ndisc(ipv6_repr, repr),
_ => Ok(None),
#[allow(unreachable_patterns)]
_ => unreachable!(),
},

// Don't report an error if a packet with unknown type
Expand Down Expand Up @@ -1976,7 +1999,8 @@ impl<'a> InterfaceInner<'a> {
};
match ip_repr {
IpRepr::Ipv4(ipv4_repr) => Ok(self.icmpv4_reply(ipv4_repr, icmp_reply_repr)),
_ => Err(Error::Unrecognized),
#[allow(unreachable_patterns)]
_ => unreachable!(),
}
}

Expand Down Expand Up @@ -2007,7 +2031,7 @@ impl<'a> InterfaceInner<'a> {
let ipv4_reply_repr = Ipv4Repr {
src_addr: ipv4_repr.dst_addr,
dst_addr: ipv4_repr.src_addr,
protocol: IpProtocol::Icmp,
next_header: IpProtocol::Icmp,
payload_len: icmp_repr.buffer_len(),
hop_limit: 64,
};
Expand All @@ -2020,7 +2044,7 @@ impl<'a> InterfaceInner<'a> {
let ipv4_reply_repr = Ipv4Repr {
src_addr: src_addr,
dst_addr: ipv4_repr.src_addr,
protocol: IpProtocol::Icmp,
next_header: IpProtocol::Icmp,
payload_len: icmp_repr.buffer_len(),
hop_limit: 64,
};
Expand Down Expand Up @@ -2113,7 +2137,6 @@ impl<'a> InterfaceInner<'a> {
};
Ok(self.icmpv6_reply(ipv6_repr, icmpv6_reply_repr))
}
IpRepr::Unspecified { .. } => Err(Error::Unaddressable),
}
}

Expand Down Expand Up @@ -2390,7 +2413,9 @@ impl<'a> InterfaceInner<'a> {
}

fn dispatch_ip<Tx: TxToken>(&mut self, tx_token: Tx, packet: IpPacket) -> Result<()> {
let ip_repr = packet.ip_repr().lower(&self.ip_addrs)?;
let ip_repr = packet.ip_repr();
assert!(!ip_repr.src_addr().is_unspecified());
assert!(!ip_repr.dst_addr().is_unspecified());

match self.caps.medium {
#[cfg(feature = "medium-ethernet")]
Expand All @@ -2413,7 +2438,6 @@ impl<'a> InterfaceInner<'a> {
IpRepr::Ipv4(_) => frame.set_ethertype(EthernetProtocol::Ipv4),
#[cfg(feature = "proto-ipv6")]
IpRepr::Ipv6(_) => frame.set_ethertype(EthernetProtocol::Ipv6),
_ => return,
}

ip_repr.emit(frame.payload_mut(), &caps.checksum);
Expand Down Expand Up @@ -2443,7 +2467,9 @@ impl<'a> InterfaceInner<'a> {

#[cfg(feature = "medium-ieee802154")]
fn dispatch_ieee802154<Tx: TxToken>(&mut self, tx_token: Tx, packet: IpPacket) -> Result<()> {
let ip_repr = packet.ip_repr().lower(&self.ip_addrs)?;
let ip_repr = packet.ip_repr();
assert!(!ip_repr.src_addr().is_unspecified());
assert!(!ip_repr.dst_addr().is_unspecified());

match self.caps.medium {
#[cfg(feature = "medium-ieee802154")]
Expand Down Expand Up @@ -2599,7 +2625,7 @@ impl<'a> InterfaceInner<'a> {
src_addr: iface_addr,
// Send to the group being reported
dst_addr: group_addr,
protocol: IpProtocol::Igmp,
next_header: IpProtocol::Igmp,
payload_len: igmp_repr.buffer_len(),
hop_limit: 1,
// TODO: add Router Alert IPv4 header option. See
Expand All @@ -2618,7 +2644,7 @@ impl<'a> InterfaceInner<'a> {
Ipv4Repr {
src_addr: iface_addr,
dst_addr: Ipv4Address::MULTICAST_ALL_ROUTERS,
protocol: IpProtocol::Igmp,
next_header: IpProtocol::Igmp,
payload_len: igmp_repr.buffer_len(),
hop_limit: 1,
},
Expand Down Expand Up @@ -2746,7 +2772,7 @@ mod test {
let repr = IpRepr::Ipv4(Ipv4Repr {
src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]),
dst_addr: Ipv4Address::BROADCAST,
protocol: IpProtocol::Unknown(0x0c),
next_header: IpProtocol::Unknown(0x0c),
payload_len: 0,
hop_limit: 0x40,
});
Expand Down Expand Up @@ -2805,7 +2831,7 @@ mod test {
let repr = IpRepr::Ipv4(Ipv4Repr {
src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]),
dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]),
protocol: IpProtocol::Unknown(0x0c),
next_header: IpProtocol::Unknown(0x0c),
payload_len: 0,
hop_limit: 0x40,
});
Expand All @@ -2821,7 +2847,7 @@ mod test {
header: Ipv4Repr {
src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]),
dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]),
protocol: IpProtocol::Unknown(12),
next_header: IpProtocol::Unknown(12),
payload_len: 0,
hop_limit: 64,
},
Expand All @@ -2832,7 +2858,7 @@ mod test {
Ipv4Repr {
src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]),
dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]),
protocol: IpProtocol::Icmp,
next_header: IpProtocol::Icmp,
payload_len: icmp_repr.buffer_len(),
hop_limit: 64,
},
Expand Down Expand Up @@ -2922,7 +2948,7 @@ mod test {
let ip_repr = IpRepr::Ipv4(Ipv4Repr {
src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]),
dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]),
protocol: IpProtocol::Udp,
next_header: IpProtocol::Udp,
payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(),
hop_limit: 64,
});
Expand All @@ -2946,7 +2972,7 @@ mod test {
header: Ipv4Repr {
src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]),
dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]),
protocol: IpProtocol::Udp,
next_header: IpProtocol::Udp,
payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(),
hop_limit: 64,
},
Expand All @@ -2956,7 +2982,7 @@ mod test {
Ipv4Repr {
src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]),
dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]),
protocol: IpProtocol::Icmp,
next_header: IpProtocol::Icmp,
payload_len: icmp_repr.buffer_len(),
hop_limit: 64,
},
Expand All @@ -2975,7 +3001,7 @@ mod test {
let ip_repr = IpRepr::Ipv4(Ipv4Repr {
src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]),
dst_addr: Ipv4Address::BROADCAST,
protocol: IpProtocol::Udp,
next_header: IpProtocol::Udp,
payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(),
hop_limit: 64,
});
Expand Down Expand Up @@ -3046,7 +3072,7 @@ mod test {
let ip_repr = IpRepr::Ipv4(Ipv4Repr {
src_addr: src_ip,
dst_addr: Ipv4Address::BROADCAST,
protocol: IpProtocol::Udp,
next_header: IpProtocol::Udp,
payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(),
hop_limit: 0x40,
});
Expand Down Expand Up @@ -3106,7 +3132,7 @@ mod test {
let ipv4_repr = Ipv4Repr {
src_addr: src_ipv4_addr,
dst_addr: Ipv4Address::BROADCAST,
protocol: IpProtocol::Icmp,
next_header: IpProtocol::Icmp,
hop_limit: 64,
payload_len: icmpv4_repr.buffer_len(),
};
Expand Down Expand Up @@ -3134,7 +3160,7 @@ mod test {
let expected_ipv4_repr = Ipv4Repr {
src_addr: our_ipv4_addr,
dst_addr: src_ipv4_addr,
protocol: IpProtocol::Icmp,
next_header: IpProtocol::Icmp,
hop_limit: 64,
payload_len: expected_icmpv4_repr.buffer_len(),
};
Expand Down Expand Up @@ -3192,7 +3218,7 @@ mod test {
let ip_repr = Ipv4Repr {
src_addr: src_addr,
dst_addr: dst_addr,
protocol: IpProtocol::Udp,
next_header: IpProtocol::Udp,
hop_limit: 64,
payload_len: udp_repr.header_len() + MAX_PAYLOAD_LEN,
};
Expand Down Expand Up @@ -3231,7 +3257,7 @@ mod test {
let expected_ip_repr = Ipv4Repr {
src_addr: dst_addr,
dst_addr: src_addr,
protocol: IpProtocol::Icmp,
next_header: IpProtocol::Icmp,
hop_limit: 64,
payload_len: expected_icmp_repr.buffer_len(),
};
Expand Down Expand Up @@ -3541,7 +3567,7 @@ mod test {
let ipv4_repr = Ipv4Repr {
src_addr: Ipv4Address::new(0x7f, 0x00, 0x00, 0x02),
dst_addr: Ipv4Address::new(0x7f, 0x00, 0x00, 0x01),
protocol: IpProtocol::Icmp,
next_header: IpProtocol::Icmp,
payload_len: 24,
hop_limit: 64,
};
Expand Down Expand Up @@ -3714,7 +3740,7 @@ mod test {
let reports = recv_igmp(&mut iface, timestamp);
assert_eq!(reports.len(), 2);
for (i, group_addr) in groups.iter().enumerate() {
assert_eq!(reports[i].0.protocol, IpProtocol::Igmp);
assert_eq!(reports[i].0.next_header, IpProtocol::Igmp);
assert_eq!(reports[i].0.dst_addr, *group_addr);
assert_eq!(
reports[i].1,
Expand Down Expand Up @@ -3758,7 +3784,7 @@ mod test {
let leaves = recv_igmp(&mut iface, timestamp);
assert_eq!(leaves.len(), 2);
for (i, group_addr) in groups.iter().cloned().enumerate() {
assert_eq!(leaves[i].0.protocol, IpProtocol::Igmp);
assert_eq!(leaves[i].0.next_header, IpProtocol::Igmp);
assert_eq!(leaves[i].0.dst_addr, Ipv4Address::MULTICAST_ALL_ROUTERS);
assert_eq!(leaves[i].1, IgmpRepr::LeaveGroup { group_addr });
}
Expand Down Expand Up @@ -3804,7 +3830,7 @@ mod test {
let ipv4_repr = Ipv4Repr {
src_addr: src_addr,
dst_addr: dst_addr,
protocol: IpProtocol::Udp,
next_header: IpProtocol::Udp,
hop_limit: 64,
payload_len: udp_repr.header_len() + PAYLOAD_LEN,
};
Expand Down Expand Up @@ -3874,7 +3900,7 @@ mod test {
let ipv4_repr = Ipv4Repr {
src_addr: src_addr,
dst_addr: dst_addr,
protocol: IpProtocol::Udp,
next_header: IpProtocol::Udp,
hop_limit: 64,
payload_len: udp_repr.header_len() + PAYLOAD_LEN,
};
Expand Down Expand Up @@ -3965,7 +3991,7 @@ mod test {
let ipv4_repr = Ipv4Repr {
src_addr: src_addr,
dst_addr: dst_addr,
protocol: IpProtocol::Udp,
next_header: IpProtocol::Udp,
hop_limit: 64,
payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(),
};
Expand Down

0 comments on commit feedb3f

Please sign in to comment.