Skip to content

Commit

Permalink
[flow] Simplify Marshal_tools
Browse files Browse the repository at this point in the history
Summary:
Marshal_tools exists because Marshal.to_channel and .from_channel use OCaml's
buffered channels, but buffering can be problematic when channels are passed
between processes.

Reading a serialized value from a file descriptor means we need to know how much
to read. Marshal_tools solved this by adding a preamble with the size. However,
the Marshal code already adds a prefix with size information for exactly this
purpose. We can get the size by reading the header into a buffer, then using
Marshal.data_size.

This is all documented here: https://ocaml.org/manual/4.14/api/Marshal.html#VALheader_size

Reviewed By: SamChou19815

Differential Revision: D57147207

fbshipit-source-id: 6de467d6df84deb38b478e956b777f56fabe00fe
  • Loading branch information
samwgoldman authored and facebook-github-bot committed May 9, 2024
1 parent e71943f commit 66b873c
Show file tree
Hide file tree
Showing 18 changed files with 69 additions and 153 deletions.
4 changes: 2 additions & 2 deletions src/commands/commandConnectSimple.ml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ let open_connection ~timeout ~client_handshake sockaddr =
Marshal.to_string (snd client_handshake) []
)
in
Marshal_tools.to_fd_with_preamble fd wire |> ignore;
Marshal_tools.to_fd fd wire |> ignore;
conn
)

Expand All @@ -102,7 +102,7 @@ let get_handshake ~timeout sockaddr ic oc =
SocketHandshake.(
try
let fd = Timeout.descr_of_in_channel ic in
let wire = (Marshal_tools.from_fd_with_preamble ~timeout fd : server_handshake_wire) in
let wire = (Marshal_tools.from_fd ~timeout fd : server_handshake_wire) in
let server_handshake =
( fst wire |> Hh_json.json_of_string |> json_to__monitor_to_client_1,
snd wire |> Base.Option.map ~f:(fun s : monitor_to_client_2 -> Marshal.from_string s 0)
Expand Down
4 changes: 2 additions & 2 deletions src/commands/commandUtils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1689,7 +1689,7 @@ let rec connect_and_make_request flowconfig_name =
command = cmd;
}
in
Marshal_tools.to_fd_with_preamble ?timeout (Unix.descr_of_out_channel oc) command |> ignore;
Marshal_tools.to_fd ?timeout (Unix.descr_of_out_channel oc) command |> ignore;
flush oc
in
let eprintf_with_spinner msg =
Expand All @@ -1704,7 +1704,7 @@ let rec connect_and_make_request flowconfig_name =
let rec wait_for_response ?timeout ~quiet ~emoji ~root (ic : Timeout.in_channel) =
let use_emoji = Tty.supports_emoji () && emoji in
let response : MonitorProt.monitor_to_client_message =
try Marshal_tools.from_fd_with_preamble ?timeout (Timeout.descr_of_in_channel ic) with
try Marshal_tools.from_fd ?timeout (Timeout.descr_of_in_channel ic) with
| Unix.Unix_error ((Unix.EPIPE | Unix.ECONNRESET), _, _) ->
if (not quiet) && Tty.spinner_used () then Tty.print_clear_line stderr;
raise End_of_file
Expand Down
6 changes: 3 additions & 3 deletions src/hack_forked/dfind/dfindLibLwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ let init log_fds (scuba_table, roots) =
let pid handle = handle.daemon_handle.Daemon.pid

let wait_until_ready handle =
let%lwt msg = Marshal_tools_lwt.from_fd_with_preamble handle.infd in
let%lwt msg = Marshal_tools_lwt.from_fd handle.infd in
assert (msg = DfindServer.Ready);
Lwt.return ()

let request_changes handle =
let%lwt _ = Marshal_tools_lwt.to_fd_with_preamble handle.outfd () in
Marshal_tools_lwt.from_fd_with_preamble handle.infd
let%lwt _ = Marshal_tools_lwt.to_fd handle.outfd () in
Marshal_tools_lwt.from_fd handle.infd

let get_changes handle =
let rec loop acc =
Expand Down
6 changes: 3 additions & 3 deletions src/hack_forked/dfind/dfindServer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,18 @@ let run_daemon (scuba_table, roots) (ic, oc) =
let env = DfindEnv.make roots in
Base.List.iter ~f:(DfindAddFile.path env) roots;
FlowEventLogger.dfind_ready scuba_table t;
Marshal_tools.to_fd_with_preamble outfd Ready |> ignore;
Marshal_tools.to_fd outfd Ready |> ignore;
ignore @@ Hh_logger.log_duration "Initialization" t;
let acc = ref SSet.empty in
let descr_in = Daemon.descr_of_in_channel ic in
let fsnotify_callback events =
acc := Base.List.fold events ~f:(process_fsnotify_event env) ~init:!acc
in
let message_in_callback () =
let () = Marshal_tools.from_fd_with_preamble infd in
let () = Marshal_tools.from_fd infd in
let count = SSet.cardinal !acc in
if count > 0 then Hh_logger.log "Sending %d file updates" count;
Marshal_tools.to_fd_with_preamble outfd (Updates !acc) |> ignore;
Marshal_tools.to_fd outfd (Updates !acc) |> ignore;
acc := SSet.empty
in
while true do
Expand Down
8 changes: 4 additions & 4 deletions src/hack_forked/procs/worker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ let worker_job_main infd outfd =

let len =
Measure.time "worker_send_response" (fun () ->
try Marshal_tools.to_fd_with_preamble ~flags:[Marshal.Closures] outfd (Some data) with
try Marshal_tools.to_fd ~flags:[Marshal.Closures] outfd (Some data) with
| Unix.Unix_error (Unix.EPIPE, _, _) -> raise Connection_closed
)
in
Expand All @@ -117,7 +117,7 @@ let worker_job_main infd outfd =

let stats = Measure.serialize (Measure.pop_global ()) in
let _ =
try Marshal_tools.to_fd_with_preamble outfd stats with
try Marshal_tools.to_fd outfd stats with
| Unix.Unix_error (Unix.EPIPE, _, _) -> raise Connection_closed
in
result_sent := true
Expand All @@ -126,7 +126,7 @@ let worker_job_main infd outfd =
Measure.push_global ();
let (Request do_process) =
Measure.time "worker_read_request" (fun () ->
try Marshal_tools.from_fd_with_preamble infd with
try Marshal_tools.from_fd infd with
| End_of_file -> raise Connection_closed
)
in
Expand All @@ -147,7 +147,7 @@ let worker_job_main infd outfd =
| WorkerCancel.Worker_should_cancel ->
(* Send `None` to reflect canceled status. *)
if not !result_sent then (
try ignore (Marshal_tools.to_fd_with_preamble outfd None) with
try ignore (Marshal_tools.to_fd outfd None) with
| Unix.Unix_error (Unix.EPIPE, _, _) -> raise Connection_closed
)
with
Expand Down
16 changes: 7 additions & 9 deletions src/hack_forked/procs/workerController.ml
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,7 @@ let send_non_blocking worker worker_pid infd outfd (f : 'a -> 'b) (x : 'a) : uni
Lwt.protected
( (* This write must happen first, to synchronize with the Canceled exception handler. *)
sent_request := true;
let%lwt _ =
Marshal_tools_lwt.to_fd_with_preamble ~flags:[Stdlib.Marshal.Closures] outfd request
in
let%lwt _ = Marshal_tools_lwt.to_fd ~flags:[Stdlib.Marshal.Closures] outfd request in
Lwt.return_unit
)
with
Expand All @@ -200,10 +198,10 @@ let send_non_blocking worker worker_pid infd outfd (f : 'a -> 'b) (x : 'a) : uni
(* We should not be canceled again at this point, but just in case prevent this operation
from being canceled. We will re-raise the Canceled exception anyway. *)
Lwt.no_cancel
(match%lwt Marshal_tools_lwt.from_fd_with_preamble infd with
(match%lwt Marshal_tools_lwt.from_fd infd with
| None -> Lwt.return_unit
| Some _ ->
let%lwt _ = Marshal_tools_lwt.from_fd_with_preamble infd in
let%lwt _ = Marshal_tools_lwt.from_fd infd in
Lwt.return_unit)
else
Lwt.return_unit
Expand Down Expand Up @@ -246,13 +244,13 @@ let read_non_blocking (type result) worker_pid infd : (result * Measure.record_d
Lwt.protected
( (* This write must happen first, to synchronize with the Canceled exception handler. *)
read_response := true;
let%lwt (data : result option) = Marshal_tools_lwt.from_fd_with_preamble infd in
let%lwt (data : result option) = Marshal_tools_lwt.from_fd infd in
match data with
| None ->
Lwt.wakeup signal_finished_read ();
Lwt.return_none
| Some data ->
let%lwt (stats : Measure.record_data) = Marshal_tools_lwt.from_fd_with_preamble infd in
let%lwt (stats : Measure.record_data) = Marshal_tools_lwt.from_fd infd in
Lwt.wakeup signal_finished_read ();
Lwt.return (Some (data, stats))
)
Expand All @@ -274,10 +272,10 @@ let read_non_blocking (type result) worker_pid infd : (result * Measure.record_d
(* We should not be canceled again at this point, but just in case prevent this operation
from being canceled. We will re-raise the Canceled exception anyway. *)
Lwt.no_cancel
(match%lwt Marshal_tools_lwt.from_fd_with_preamble infd with
(match%lwt Marshal_tools_lwt.from_fd infd with
| None -> Lwt.return_unit
| Some _ ->
let%lwt _ = Marshal_tools_lwt.from_fd_with_preamble infd in
let%lwt _ = Marshal_tools_lwt.from_fd infd in
Lwt.return_unit)
in
Exception.reraise exn
Expand Down
11 changes: 5 additions & 6 deletions src/hack_forked/utils/jsonrpc/jsonrpc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,15 @@ let internal_run_daemon' (oc : queue_message Daemon.out_channel) : unit =
| Hh_json.Syntax_error _ -> (true, Recoverable_exception edata)
| _ -> (false, Fatal_exception edata)
in
Marshal_tools.to_fd_with_preamble out_fd marshal |> ignore;
Marshal_tools.to_fd out_fd marshal |> ignore;
should_continue
end
| Write ->
assert (not (Queue.is_empty messages_to_send));
let timestamped_json = Queue.pop messages_to_send in
(* We can assume that the entire write will succeed, since otherwise
Marshal_tools.to_fd_with_preamble will throw an exception. *)
Marshal_tools.to_fd_with_preamble out_fd (Timestamped_json timestamped_json) |> ignore;
Marshal_tools.to_fd will throw an exception. *)
Marshal_tools.to_fd out_fd (Timestamped_json timestamped_json) |> ignore;
true
in
if should_continue then loop ()
Expand All @@ -196,8 +196,7 @@ let internal_run_daemon (_dummy_param : unit) (_ic, (oc : queue_message Daemon.o
let stack = Exception.get_full_backtrace_string 500 e in
(try
let out_fd = Daemon.descr_of_out_channel oc in
Marshal_tools.to_fd_with_preamble out_fd (Fatal_exception { Marshal_tools.message; stack })
|> ignore
Marshal_tools.to_fd out_fd (Fatal_exception { Marshal_tools.message; stack }) |> ignore
with
| _ ->
(* There may be a broken pipe, for example. We should just give up on
Expand Down Expand Up @@ -232,7 +231,7 @@ let get_read_fd (queue : queue) : Unix.file_descr = Lwt_unix.unix_file_descr que
let read_single_message_into_queue_wait (message_queue : queue) : queue_message Lwt.t =
let%lwt message =
try%lwt
let%lwt message = Marshal_tools_lwt.from_fd_with_preamble message_queue.daemon_in_fd in
let%lwt message = Marshal_tools_lwt.from_fd message_queue.daemon_in_fd in
Lwt.return message
with
| End_of_file as e ->
Expand Down
100 changes: 21 additions & 79 deletions src/hack_forked/utils/marshal_tools/marshal_tools.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,11 @@
* libancillary) to start reading the next object.
*
* The solution:
* Start each message with a fixed-size preamble that describes the
* size of the payload to read. Read precisely that many bytes directly
* from the FD avoiding Ocaml channels entirely.
* Use unbuffered IO directly with Unix file descriptors. We can read the size
* of the serialized value by first reading the header then extracting the
* size, using the Marshal.data_size API.
*)

exception Payload_Size_Too_Large_Exception

exception Malformed_Preamble_Exception

exception Writing_Preamble_Exception

exception Writing_Payload_Exception

(* We want to marshal exceptions (or at least their message+stacktrace) over *)
Expand Down Expand Up @@ -88,7 +82,7 @@ module RegularWriterReader : REGULAR_WRITER_READER = struct
*
* People using Marshal_tools probably are calling Unix.select first. However that only guarantees
* that the first read won't block. Marshal_tools will always do at least 2 reads (one for the
* preamble and one or more for the data). Any read after the first might block.
* header and one or more for the data). Any read after the first might block.
*)
let rec read ?timeout fd ~buffer ~offset ~size =
match Timeout.select ?timeout [fd] [] [] ~-.1.0 with
Expand All @@ -103,66 +97,17 @@ module RegularWriterReader : REGULAR_WRITER_READER = struct
end

module MarshalToolsFunctor (WriterReader : WRITER_READER) : sig
val expected_preamble_size : int

val to_fd_with_preamble :
val to_fd :
?timeout:Timeout.t ->
?flags:Marshal.extern_flags list ->
WriterReader.fd ->
'a ->
int WriterReader.result

val from_fd_with_preamble : ?timeout:Timeout.t -> WriterReader.fd -> 'a WriterReader.result
val from_fd : ?timeout:Timeout.t -> WriterReader.fd -> 'a WriterReader.result
end = struct
let ( >>= ) = WriterReader.( >>= )

(* Marker byte that begins every preamble; parse_preamble rejects a preamble
   whose first byte differs from it. *)
let preamble_start_sentinel = '\142'

(* Size in bytes. *)
(* Number of preamble bytes used to encode the payload size. *)
let preamble_core_size = 4

(* Total preamble length: one sentinel byte followed by the size bytes. *)
let expected_preamble_size = preamble_core_size + 1

(* Exclusive upper bound on the payload size in bytes: 2^32 - 1 for the 4-byte
   core (assuming a 64-bit OCaml int). get_preamble_core raises when
   size >= this value. *)
let maximum_payload_size = (1 lsl (preamble_core_size * 8)) - 1

(* Encodes [size] as a big-endian integer occupying [preamble_core_size] bytes
   and returns the freshly allocated buffer.
   Raises Payload_Size_Too_Large_Exception when [size] is not strictly below
   [maximum_payload_size]. *)
let get_preamble_core (size : int) =
  if size >= maximum_payload_size then raise Payload_Size_Too_Large_Exception;
  let core = Bytes.create preamble_core_size in
  let remaining = ref size in
  (* Fill from the last byte backwards so the least significant byte ends up
     at the highest index. *)
  for idx = preamble_core_size - 1 downto 0 do
    Bytes.set core idx (Char.chr (!remaining mod 256));
    remaining := !remaining / 256
  done;
  core

(* Builds the complete preamble for a payload of [size] bytes: the sentinel
   byte at index 0 followed by the size bytes produced by get_preamble_core.
   Propagates Payload_Size_Too_Large_Exception from get_preamble_core. *)
let make_preamble (size : int) =
  let preamble_core = get_preamble_core size in
  let preamble = Bytes.create (preamble_core_size + 1) in
  Bytes.set preamble 0 preamble_start_sentinel;
  (* Copy exactly preamble_core_size bytes rather than a literal 4, so the
     copy length cannot drift from the constant that sizes both buffers. *)
  Bytes.blit preamble_core 0 preamble 1 preamble_core_size;
  preamble

(* Decodes the payload size stored in [preamble].
   Raises Malformed_Preamble_Exception when [preamble] does not have exactly
   [expected_preamble_size] bytes or does not start with
   [preamble_start_sentinel]. *)
let parse_preamble preamble =
  if
    Bytes.length preamble <> expected_preamble_size
    || Bytes.get preamble 0 <> preamble_start_sentinel
  then
    raise Malformed_Preamble_Exception;
  (* Bytes 1 .. expected_preamble_size - 1 hold the size, most significant
     byte first. The bound uses the named constant instead of a literal 5 so
     it stays in step with the length check above. *)
  let rec loop i acc =
    if i >= expected_preamble_size then
      acc
    else
      loop (i + 1) ((acc * 256) + int_of_char (Bytes.get preamble i))
  in
  loop 1 0

let rec write_payload ?timeout fd buffer offset to_write =
if to_write = 0 then
WriterReader.return offset
Expand All @@ -174,19 +119,14 @@ end = struct
write_payload ?timeout fd buffer (offset + bytes_written) (to_write - bytes_written)

(* Returns the size of the marshaled payload *)
let to_fd_with_preamble ?timeout ?(flags = []) fd obj =
let to_fd ?timeout ?(flags = []) fd obj =
let payload = Marshal.to_bytes obj flags in
let size = Bytes.length payload in
let preamble = make_preamble size in
write_payload ?timeout fd preamble 0 expected_preamble_size >>= fun preamble_bytes_written ->
if preamble_bytes_written <> expected_preamble_size then
raise Writing_Preamble_Exception
write_payload ?timeout fd payload 0 size >>= fun bytes_written ->
if bytes_written <> size then
raise Writing_Payload_Exception
else
write_payload ?timeout fd payload 0 size >>= fun bytes_written ->
if bytes_written <> size then
raise Writing_Payload_Exception
else
WriterReader.return size
WriterReader.return size

let rec read_payload ?timeout fd buffer offset to_read =
if to_read = 0 then
Expand All @@ -198,16 +138,18 @@ end = struct
else
read_payload ?timeout fd buffer (offset + bytes_read) (to_read - bytes_read)

let from_fd_with_preamble ?timeout fd =
let preamble = Bytes.create expected_preamble_size in
read_payload ?timeout fd preamble 0 expected_preamble_size >>= fun bytes_read ->
if bytes_read <> expected_preamble_size then
let from_fd ?timeout fd =
let header = Bytes.create Marshal.header_size in
read_payload ?timeout fd header 0 Marshal.header_size >>= fun bytes_read ->
if bytes_read <> Marshal.header_size then
raise End_of_file
else
let payload_size = parse_preamble preamble in
let payload = Bytes.create payload_size in
read_payload ?timeout fd payload 0 payload_size >>= fun payload_size_read ->
if payload_size_read <> payload_size then
let data_size = Marshal.data_size header 0 in
let payload = Bytes.create (Marshal.header_size + data_size) in
Bytes.unsafe_blit header 0 payload 0 Marshal.header_size;
read_payload ?timeout fd payload Marshal.header_size data_size >>= fun offset_after_read ->
let data_size_read = offset_after_read - Marshal.header_size in
if data_size_read <> data_size then
raise End_of_file
else
WriterReader.return (Marshal.from_bytes payload 0)
Expand Down
17 changes: 4 additions & 13 deletions src/hack_forked/utils/marshal_tools/marshal_tools.mli
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,16 @@
* LICENSE file in the root directory of this source tree.
*)

exception Payload_Size_Too_Large_Exception

exception Malformed_Preamble_Exception

exception Writing_Preamble_Exception

exception Writing_Payload_Exception

type remote_exception_data = {
message: string;
stack: string;
}

val to_fd_with_preamble :
?timeout:Timeout.t -> ?flags:Marshal.extern_flags list -> Unix.file_descr -> 'a -> int
val to_fd : ?timeout:Timeout.t -> ?flags:Marshal.extern_flags list -> Unix.file_descr -> 'a -> int

val from_fd_with_preamble : ?timeout:Timeout.t -> Unix.file_descr -> 'a
val from_fd : ?timeout:Timeout.t -> Unix.file_descr -> 'a

module type WRITER_READER = sig
type 'a result
Expand All @@ -38,14 +31,12 @@ module type WRITER_READER = sig
end

module MarshalToolsFunctor (WriterReader : WRITER_READER) : sig
val expected_preamble_size : int

val to_fd_with_preamble :
val to_fd :
?timeout:Timeout.t ->
?flags:Marshal.extern_flags list ->
WriterReader.fd ->
'a ->
int WriterReader.result

val from_fd_with_preamble : ?timeout:Timeout.t -> WriterReader.fd -> 'a WriterReader.result
val from_fd : ?timeout:Timeout.t -> WriterReader.fd -> 'a WriterReader.result
end

0 comments on commit 66b873c

Please sign in to comment.