Skip to content

Commit

Permalink
9713 conch types (#12142)
Browse files Browse the repository at this point in the history
  • Loading branch information
glyph committed May 1, 2024
2 parents 02a2b65 + 9da1372 commit ac0d3a5
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 161 deletions.
98 changes: 50 additions & 48 deletions src/twisted/conch/client/knownhosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,25 @@
@since: 8.2
"""

from __future__ import annotations

import hmac
import sys
from binascii import Error as DecodeError, a2b_base64, b2a_base64
from contextlib import closing
from hashlib import sha1
from typing import IO, Callable, Literal

from zope.interface import implementer

from twisted.conch.error import HostKeyChanged, InvalidEntry, UserRejectedKey
from twisted.conch.interfaces import IKnownHostEntry
from twisted.conch.ssh.keys import BadKeyError, FingerprintFormats, Key
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.logger import Logger
from twisted.python.compat import nativeString
from twisted.python.filepath import FilePath
from twisted.python.randbytes import secureRandom
from twisted.python.util import FancyEqMixin

Expand Down Expand Up @@ -111,65 +115,61 @@ class PlainEntry(_BaseEntry):
file.
@ivar _hostnames: the list of all host-names associated with this entry.
@type _hostnames: L{list} of L{bytes}
"""

def __init__(self, hostnames, keyType, publicKey, comment):
self._hostnames = hostnames
def __init__(
self, hostnames: list[bytes], keyType: bytes, publicKey: Key, comment: bytes
):
self._hostnames: list[bytes] = hostnames
super().__init__(keyType, publicKey, comment)

@classmethod
def fromString(cls, string):
def fromString(cls, string: bytes) -> PlainEntry:
"""
Parse a plain-text entry in a known_hosts file, and return a
corresponding L{PlainEntry}.
@param string: a space-separated string formatted like "hostname
key-type base64-key-data comment".
@type string: L{bytes}
key-type base64-key-data comment".
@raise DecodeError: if the key is not valid encoded as valid base64.
@raise InvalidEntry: if the entry does not have the right number of
elements and is therefore invalid.
elements and is therefore invalid.
@raise BadKeyError: if the key, once decoded from base64, is not
actually an SSH key.
actually an SSH key.
@return: an IKnownHostEntry representing the hostname and key in the
input line.
input line.
@rtype: L{PlainEntry}
"""
hostnames, keyType, key, comment = _extractCommon(string)
self = cls(hostnames.split(b","), keyType, key, comment)
return self

def matchesHost(self, hostname):
def matchesHost(self, hostname: bytes | str) -> bool:
"""
Check to see if this entry matches a given hostname.
@param hostname: A hostname or IP address literal to check against this
entry.
@type hostname: L{bytes}
@return: C{True} if this entry is for the given hostname or IP address,
C{False} otherwise.
@rtype: L{bool}
"""
if isinstance(hostname, str):
hostname = hostname.encode("utf-8")
return hostname in self._hostnames

def toString(self):
def toString(self) -> bytes:
"""
Implement L{IKnownHostEntry.toString} by recording the comma-separated
hostnames, key type, and base-64 encoded key.
@return: The string representation of this entry, with unhashed hostname
information.
@rtype: L{bytes}
"""
fields = [
b",".join(self._hostnames),
Expand Down Expand Up @@ -256,33 +256,39 @@ class HashedEntry(_BaseEntry, FancyEqMixin):

compareAttributes = ("_hostSalt", "_hostHash", "keyType", "publicKey", "comment")

def __init__(self, hostSalt, hostHash, keyType, publicKey, comment):
def __init__(
self,
hostSalt: bytes,
hostHash: bytes,
keyType: bytes,
publicKey: Key,
comment: bytes | None,
) -> None:
self._hostSalt = hostSalt
self._hostHash = hostHash
super().__init__(keyType, publicKey, comment)

@classmethod
def fromString(cls, string):
def fromString(cls, string: bytes) -> HashedEntry:
"""
Load a hashed entry from a string representing a line in a known_hosts
file.
@param string: A complete single line from a I{known_hosts} file,
formatted as defined by OpenSSH.
@type string: L{bytes}
@raise DecodeError: if the key, the hostname, or the is not valid
encoded as valid base64
@raise InvalidEntry: if the entry does not have the right number of
elements and is therefore invalid, or the host/hash portion contains
more items than just the host and hash.
elements and is therefore invalid, or the host/hash portion
contains more items than just the host and hash.
@raise BadKeyError: if the key, once decoded from base64, is not
actually an SSH key.
@return: The newly created L{HashedEntry} instance, initialized with the
information from C{string}.
@return: The newly created L{HashedEntry} instance, initialized with
the information from C{string}.
"""
stuff, keyType, key, comment = _extractCommon(string)
saltAndHash = stuff[len(cls.MAGIC) :].split(b"|")
Expand Down Expand Up @@ -346,7 +352,7 @@ class KnownHostsFile:
@ivar _savePath: See C{savePath} parameter of L{__init__}.
"""

def __init__(self, savePath):
def __init__(self, savePath: FilePath[str]) -> None:
"""
Create a new, empty KnownHostsFile.
Expand All @@ -356,12 +362,12 @@ def __init__(self, savePath):
@param savePath: The L{FilePath} to which to save new entries.
@type savePath: L{FilePath}
"""
self._added = []
self._added: list[IKnownHostEntry] = []
self._savePath = savePath
self._clobber = True

@property
def savePath(self):
def savePath(self) -> FilePath[str]:
"""
@see: C{savePath} parameter of L{__init__}
"""
Expand Down Expand Up @@ -431,7 +437,9 @@ def hasHostKey(self, hostname, key):
raise HostKeyChanged(entry, path, line)
return False

def verifyHostKey(self, ui, hostname, ip, key):
def verifyHostKey(
self, ui: ConsoleUI, hostname: bytes, ip: bytes, key: Key
) -> Deferred[bool]:
"""
Verify the given host key for the given IP and host, asking for
confirmation from, and notifying, the given UI about changes to this
Expand All @@ -453,20 +461,21 @@ def verifyHostKey(self, ui, hostname, ip, key):
"""
hhk = defer.execute(self.hasHostKey, hostname, key)

def gotHasKey(result):
def gotHasKey(result: bool) -> bool | Deferred[bool]:
if result:
if not self.hasHostKey(ip, key):
ui.warn(
"Warning: Permanently added the %s host key for "
"IP address '%s' to the list of known hosts."
% (key.type(), nativeString(ip))
addMessage = (
f"Warning: Permanently added the {key.type()} host key"
f" for IP address '{ip.decode()}' to the list of known"
" hosts.\n"
)
ui.warn(addMessage.encode("utf-8"))
self.addHostKey(ip, key)
self.save()
return result
else:

def promptResponse(response):
def promptResponse(response: bool) -> bool:
if response:
self.addHostKey(hostname, key)
self.addHostKey(ip, key)
Expand All @@ -475,7 +484,7 @@ def promptResponse(response):
else:
raise UserRejectedKey()

keytype = key.type()
keytype: str = key.type()

if keytype == "EC":
keytype = "ECDSA"
Expand All @@ -497,7 +506,7 @@ def promptResponse(response):

return hhk.addCallback(gotHasKey)

def addHostKey(self, hostname, key):
def addHostKey(self, hostname: bytes, key: Key) -> HashedEntry:
"""
Add a new L{HashedEntry} to the key database.
Expand All @@ -520,19 +529,15 @@ def addHostKey(self, hostname, key):
self._added.append(entry)
return entry

def save(self):
def save(self) -> None:
"""
Save this L{KnownHostsFile} to the path it was loaded from.
"""
p = self._savePath.parent()
if not p.isdir():
p.makedirs()

if self._clobber:
mode = "wb"
else:
mode = "ab"

mode: Literal["a", "w"] = "w" if self._clobber else "a"
with self._savePath.open(mode) as hostsFileObj:
if self._added:
hostsFileObj.write(
Expand All @@ -542,18 +547,16 @@ def save(self):
self._clobber = False

@classmethod
def fromPath(cls, path):
def fromPath(cls, path: FilePath[str]) -> KnownHostsFile:
"""
Create a new L{KnownHostsFile}, potentially reading existing known
hosts information from the given file.
@param path: A path object to use for both reading contents from and
later saving to. If no file exists at this path, it is not an
error; a L{KnownHostsFile} with no entries is returned.
@type path: L{FilePath}
@return: A L{KnownHostsFile} initialized with entries from C{path}.
@rtype: L{KnownHostsFile}
"""
knownHosts = cls(path)
knownHosts._clobber = False
Expand All @@ -566,7 +569,7 @@ class ConsoleUI:
console, to be used during key verification.
"""

def __init__(self, opener):
def __init__(self, opener: Callable[[], IO[bytes]]):
"""
@param opener: A no-argument callable which should open a console
binary-mode file-like object to be used for reading and writing.
Expand All @@ -576,7 +579,7 @@ def __init__(self, opener):
"""
self.opener = opener

def prompt(self, text):
def prompt(self, text: bytes) -> Deferred[bool]:
"""
Write the given text as a prompt to the console output, then read a
result from the console input.
Expand All @@ -598,20 +601,19 @@ def body(ignored):
answer = f.readline().strip().lower()
if answer == b"yes":
return True
elif answer == b"no":
elif answer in {b"no", b""}:
return False
else:
f.write(b"Please type 'yes' or 'no': ")

return d.addCallback(body)

def warn(self, text):
def warn(self, text: bytes) -> None:
"""
Notify the user (non-interactively) of the provided text, by writing it
to the console.
@param text: Some information the user is to be made aware of.
@type text: L{bytes}
"""
try:
with closing(self.opener()) as f:
Expand Down

0 comments on commit ac0d3a5

Please sign in to comment.