Skip to content

Commit

Permalink
Fix request_port_forward to be reenterant
Browse files Browse the repository at this point in the history
  • Loading branch information
pyhedgehog committed Apr 9, 2024
1 parent 51eb55d commit f7a5313
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 16 deletions.
13 changes: 8 additions & 5 deletions paramiko/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def __init__(
self.default_window_size = default_window_size
self._forward_agent_handler = None
self._x11_handler = None
self._tcp_handler = None
self._tcp_handlers = {}

self.saved_exception = None
self.clear_to_send = threading.Event()
Expand Down Expand Up @@ -1135,6 +1135,8 @@ def request_port_forward(self, address, port, handler=None):
"""
if not self.active:
raise SSHException("SSH session not active")
if (address, port) in self._tcp_handlers:
raise SSHException("TCP forwarding port already used")
port = int(port)
response = self.global_request(
"tcpip-forward", (address, port), wait=True
Expand All @@ -1151,7 +1153,7 @@ def default_handler(channel, src_addr, dest_addr_port):
self._queue_incoming_channel(channel)

handler = default_handler
self._tcp_handler = handler
self._tcp_handlers[address, port] = handler
return port

def cancel_port_forward(self, address, port):
Expand All @@ -1165,7 +1167,8 @@ def cancel_port_forward(self, address, port):
"""
if not self.active:
return
self._tcp_handler = None
if (address, port) in self._tcp_handlers:
del self._tcp_handlers[address, port]
self.global_request("cancel-tcpip-forward", (address, port), wait=True)

def open_sftp_client(self):
Expand Down Expand Up @@ -2979,7 +2982,7 @@ def _parse_channel_open(self, m):
my_chanid = self._next_channel()
finally:
self.lock.release()
elif (kind == "forwarded-tcpip") and (self._tcp_handler is not None):
elif (kind == "forwarded-tcpip") and (len(self._tcp_handlers) > 0):
server_addr = m.get_text()
server_port = m.get_int()
origin_addr = m.get_text()
Expand Down Expand Up @@ -3069,7 +3072,7 @@ def _parse_channel_open(self, m):
self._x11_handler(chan, (origin_addr, origin_port))
elif kind == "forwarded-tcpip":
chan.origin_addr = (origin_addr, origin_port)
self._tcp_handler(
self._tcp_handlers[server_addr, server_port](
chan, (origin_addr, origin_port), (server_addr, server_port)
)
else:
Expand Down
16 changes: 10 additions & 6 deletions tests/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ class TestServer(ServerInterface):

def __init__(self, allowed_keys=None):
self.allowed_keys = allowed_keys if allowed_keys is not None else []
self._listen = {}

def check_channel_request(self, kind, chanid):
if kind == "bogus":
Expand Down Expand Up @@ -253,14 +254,17 @@ def check_channel_x11_request(
return True

def check_port_forward_request(self, addr, port):
self._listen = socket.socket()
self._listen.bind(("127.0.0.1", 0))
self._listen.listen(1)
return self._listen.getsockname()[1]
assert (addr, port) not in self._listen
listen = socket.socket()
listen.bind((addr, port))
listen.listen(1)
port = listen.getsockname()[1]
self._listen[addr, port] = listen
return port

def cancel_port_forward_request(self, addr, port):
self._listen.close()
self._listen = None
self._listen[addr, port].close()
del self._listen[addr, port]

def check_channel_direct_tcpip_request(self, chanid, origin, destination):
self._tcpip_dest = destination
Expand Down
82 changes: 77 additions & 5 deletions tests/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from binascii import hexlify
import itertools
import functools
import select
import socket
import time
Expand Down Expand Up @@ -529,13 +530,14 @@ def handler(c, origin_addr_port, server_addr_port):
self.tc._queue_incoming_channel(c)

port = self.tc.request_port_forward("127.0.0.1", 0, handler)
self.assertEqual(port, self.server._listen.getsockname()[1])
key = ("127.0.0.1", port)
self.assertEqual(port, self.server._listen[key].getsockname()[1])

cs = socket.socket()
cs.connect(("127.0.0.1", port))
ss, _ = self.server._listen.accept()
ss, _ = self.server._listen[key].accept()
sch = self.ts.open_forwarded_tcpip_channel(
ss.getsockname(), ss.getpeername()
ss.getpeername(), ss.getsockname()
)
cch = self.tc.accept()

Expand All @@ -548,7 +550,74 @@ def handler(c, origin_addr_port, server_addr_port):

# now cancel it.
self.tc.cancel_port_forward("127.0.0.1", port)
self.assertTrue(self.server._listen is None)
self.assertTrue(not self.server._listen)

def test_reverse_port_forwarding_twice(self):
"""
verify that a client can ask the server to open a reverse port for
forwarding.
"""
self.setup_test_server()
chan = self.tc.open_session()
chan.exec_command("yes")
self.ts.accept(1.0)

requested = []

def handler(c, origin_addr_port, server_addr_port, mark=None):
requested.append(server_addr_port)
requested.append(mark)
self.tc._queue_incoming_channel(c)

def process_port(port, check_key=[1]):
cs = socket.socket()
cs.connect(("127.0.0.1", port))
ss, _ = self.server._listen["127.0.0.1", port].accept()
sch = self.ts.open_forwarded_tcpip_channel(
ss.getpeername(), ss.getsockname()
)
cch = self.tc.accept()

check = b"hello%d" % (check_key[0],)
check_key[0] += 1
sch.send(check)
self.assertEqual(check, cch.recv(6))
sch.close()
cch.close()
ss.close()
cs.close()

port1 = self.tc.request_port_forward(
"127.0.0.1", 0, functools.partial(handler, mark="port1")
)
self.assertTrue(("127.0.0.1", port1) in self.server._listen)
process_port(port1)

port2 = self.tc.request_port_forward(
"127.0.0.1", 0, functools.partial(handler, mark="port2")
)
self.assertTrue(("127.0.0.1", port2) in self.server._listen)
process_port(port2)
process_port(port1)

# Split checks to see what step was failed
self.assertEqual(len(requested), 6)
self.assertEqual(
requested[:2] + ["assert1"],
[("127.0.0.1", port1), "port1", "assert1"],
)
self.assertEqual(
requested[2:4] + ["assert2"],
[("127.0.0.1", port2), "port2", "assert2"],
)
self.assertEqual(
requested[4:] + ["assert3"],
[("127.0.0.1", port1), "port1", "assert3"],
)
# now cancel it.
self.tc.cancel_port_forward("127.0.0.1", port1)
self.tc.cancel_port_forward("127.0.0.1", port2)
self.assertTrue(not self.server._listen)

def test_port_forwarding(self):
"""
Expand Down Expand Up @@ -1187,7 +1256,10 @@ def test_kex_with_sha2_256(self):
# No 512 -> you get 256
with server(
init=dict(disabled_algorithms=dict(keys=["rsa-sha2-512"]))
) as (tc, _):
) as (
tc,
_,
):
assert tc.host_key_type == "rsa-sha2-256"

def _incompatible_peers(self, client_init, server_init):
Expand Down

0 comments on commit f7a5313

Please sign in to comment.