Blob Blame History Raw
From 2440a8b69a118fe14e73eb6cab4a050922866f1a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Edwin=20T=C3=B6r=C3=B6k?= <edvin.torok@citrix.com>
Date: Wed, 12 Oct 2022 19:13:03 +0100
Subject: tools/ocaml/xb: Add BoundedQueue
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Ensures we cannot store more than [capacity] elements in a [Queue].  Replacing
all Queue with this module will then ensure at compile time that all Queues
are correctly bound checked.

Each element in the queue has a class with its own limits.  This, in a
subsequent change, will ensure that command responses can proceed during a
flood of watch events.

No functional change.

This is part of XSA-326.

Reported-by: Julien Grall <jgrall@amazon.com>
Signed-off-by: Edwin Török <edvin.torok@citrix.com>
Acked-by: Christian Lindig <christian.lindig@citrix.com>

diff --git a/tools/ocaml/libs/xb/xb.ml b/tools/ocaml/libs/xb/xb.ml
index 165fd4a1edf4..4197a3888a68 100644
--- a/tools/ocaml/libs/xb/xb.ml
+++ b/tools/ocaml/libs/xb/xb.ml
@@ -17,6 +17,98 @@
 module Op = struct include Op end
 module Packet = struct include Packet end
 
+module BoundedQueue : sig
+	type ('a, 'b) t
+
+	(** [create ~capacity ~classify ~limit] creates a queue with maximum [capacity] elements.
+	    This is burst capacity, each element is further classified according to [classify],
+	    and each class can have its own [limit].
+	    [capacity] is enforced as an overall limit.
+	    The [limit] can be dynamic, and can be smaller than the number of elements already queued of that class,
+	    in which case those elements are considered to use "burst capacity".
+	  *)
+	val create: capacity:int -> classify:('a -> 'b) -> limit:('b -> int) -> ('a, 'b) t
+
+	(** [clear q] discards all elements from [q] *)
+	val clear: ('a, 'b) t -> unit
+
+	(** [can_push q] when [length q < capacity].	*)
+	val can_push: ('a, 'b) t -> 'b -> bool
+
+	(** [push e q] adds [e] at the end of queue [q] if [can_push q], or returns [None]. *)
+	val push: 'a -> ('a, 'b) t -> unit option
+
+	(** [pop q] removes and returns first element in [q], or raises [Queue.Empty]. *)
+	val pop: ('a, 'b) t -> 'a
+
+	(** [peek q] returns the first element in [q], or raises [Queue.Empty].  *)
+	val peek : ('a, 'b) t -> 'a
+
+	(** [length q] returns the current number of elements in [q] *)
+	val length: ('a, 'b) t -> int
+
+	(** [debug string_of_class q] prints queue usage statistics in an unspecified internal format. *)
+	val debug: ('b -> string) -> (_, 'b) t -> string
+end = struct
+	type ('a, 'b) t =
+		{ q: 'a Queue.t
+		; capacity: int
+		; classify: 'a -> 'b
+		; limit: 'b -> int
+		; class_count: ('b, int) Hashtbl.t
+		}
+
+	let create ~capacity ~classify ~limit =
+		{ capacity; q = Queue.create (); classify; limit; class_count = Hashtbl.create 3 }
+
+	let get_count t classification = try Hashtbl.find t.class_count classification with Not_found -> 0
+
+	let can_push_internal t classification class_count =
+		Queue.length t.q < t.capacity && class_count < t.limit classification
+
+	let ok = Some ()
+
+	let push e t =
+		let classification = t.classify e in
+		let class_count = get_count t classification in
+		if can_push_internal t classification class_count then begin
+			Queue.push e t.q;
+			Hashtbl.replace t.class_count classification (class_count + 1);
+			ok
+		end
+		else
+			None
+
+	let can_push t classification =
+		can_push_internal t classification @@ get_count t classification
+
+	let clear t =
+		Queue.clear t.q;
+		Hashtbl.reset t.class_count
+
+	let pop t =
+		let e = Queue.pop t.q in
+		let classification = t.classify e in
+		let () = match get_count t classification - 1 with
+		| 0 -> Hashtbl.remove t.class_count classification (* reduces memusage *)
+		| n -> Hashtbl.replace t.class_count classification n
+		in
+		e
+
+	let peek t = Queue.peek t.q
+	let length t = Queue.length t.q
+
+	let debug string_of_class t =
+		let b = Buffer.create 128 in
+		Printf.bprintf b "BoundedQueue capacity: %d, used: {" t.capacity;
+		Hashtbl.iter (fun packet_class count ->
+			Printf.bprintf b "	%s: %d" (string_of_class packet_class) count
+		) t.class_count;
+		Printf.bprintf b "}";
+		Buffer.contents b
+end
+
+
 exception End_of_file
 exception Eagain
 exception Noent