Reservoir Sampling

Reservoir sampling is a random sampling algorithm every data engineer should know. It extracts a random sample of a specified size from a large, possibly unbounded dataset: one that you cannot pull into memory all at once, and whose total size you don’t know when you take the sample.

Over on Wikipedia you can check out a nice explanation of the idea behind reservoir sampling. It presents the simplest reservoir sampling algorithm. Let me regurgitate it in C#:

public static T[] TakeSample<T>(IEnumerable<T> input, int sampleCount) {
	if (sampleCount <= 0)
		throw new ArgumentOutOfRangeException ("sampleCount");

	Random rand = new Random ();
	var samples = new T[sampleCount];
	int i = 0;
	foreach (T item in input) {
		if (i < sampleCount) {
			samples [i] = item;
		} else {
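			// Pick j uniformly from [0, i] (inclusive), so that item i is kept
			// with probability sampleCount / (i + 1). Using "% i" would exclude i
			// itself and always select the first item after the reservoir fills.
			int j = rand.Next (i + 1);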
			if (j < sampleCount)
				samples [j] = item;
		}
		i++;
	}

	if (i < sampleCount)
		throw new InvalidOperationException ("Input data does not have enough items to sample");

	return samples;
}

To some this algorithm may seem botched, as later items have a lower likelihood of being selected when they arrive. However, the overall likelihood of ending up in the sample is distributed evenly across the entire dataset: in the textbook algorithm every item is kept with probability k/n, where k = sample count and n = total item count. This is because the items selected earlier on (with absolute certainty for the first k items) are exposed to more successors, and so are more likely to be replaced.
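
The even distribution can be checked with a short derivation. The notation is introduced here for illustration (k = sample count, n = total item count, i = an item's 1-based position), and it assumes the textbook rule in which the item at position m > k replaces a uniformly chosen slot with probability k/m:

```latex
\begin{align*}
P(\text{item } i \text{ in final sample})
  &= \underbrace{\frac{k}{i}}_{\text{selected on arrival}}
     \cdot \prod_{m=i+1}^{n}
     \underbrace{\left(1 - \frac{k}{m}\cdot\frac{1}{k}\right)}_{\text{not replaced by item } m} \\
  &= \frac{k}{i} \cdot \prod_{m=i+1}^{n} \frac{m-1}{m}
   = \frac{k}{i} \cdot \frac{i}{n}
   = \frac{k}{n}, \qquad i > k .
\end{align*}
```

For the first k items the selection probability is 1 and the product starts at m = k + 1, which telescopes to k/n as well.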

Jeffrey Scott Vitter has published some alternative implementations that optimise the performance of the sampling.

Distributed Reservoir Sampling

When you are dealing with BIG data and have a distributed infrastructure at your disposal, you’re not going to stream the entire data set into a single sequence. Here is a crafty map-reduce approach to the sampling problem on Stack Overflow: http://stackoverflow.com/questions/2514061/how-to-pick-random-small-data-samples-using-map-reduce. I really like this approach as it combines fundamental computer science algorithm know-how with big-data technology to get the job done!
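
To give a feel for how sampling can be split into map and reduce steps, here is a minimal sketch of one common technique: tag every item with a random key, keep the k smallest keys per partition, then merge the partial results and keep the k smallest again. This is not necessarily what the linked answer does, and the names (DistributedSampleSketch, SamplePartition, Merge) are made up for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class DistributedSampleSketch
{
    // "Map" step, run once per partition: tag each item with a random key
    // and keep only the k items with the smallest keys.
    // Note: OrderBy buffers the whole partition; a real job would more likely
    // keep a bounded heap of size k instead.
    public static List<KeyValuePair<double, T>> SamplePartition<T>(
        IEnumerable<T> partition, int k, Random rand)
    {
        return partition
            .Select(item => new KeyValuePair<double, T>(rand.NextDouble(), item))
            .OrderBy(pair => pair.Key)
            .Take(k)
            .ToList();
    }

    // "Reduce" step: the k smallest keys overall must be among the
    // k smallest keys of each partition, so merging the partial results suffices.
    public static T[] Merge<T>(
        IEnumerable<List<KeyValuePair<double, T>>> partials, int k)
    {
        return partials
            .SelectMany(partial => partial)
            .OrderBy(pair => pair.Key)
            .Take(k)
            .Select(pair => pair.Value)
            .ToArray();
    }
}
```

Because the keys are independent and uniformly distributed, the items holding the k smallest keys form a uniform random sample, and each worker only ever ships k items to the reducer.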

 

 

Augmented Interval Tree in C#

Interval trees are an efficient ADT for storing and searching intervals, though probably not one that makes the top 10 most commonly used collections in computer science. I often see code where a list/array-like structure is used to store and search interval data. Sometimes that is fine – even preferred for the sake of simplicity – but only in cases where the code is rarely run and does not have to handle large volumes of data.
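
For comparison, the list-based approach looks roughly like the sketch below. SimpleInterval and LinearIntervalScan are hypothetical names used only here; the point is that every query visits all n intervals.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical closed interval [Start, End], used only for this sketch.
public class SimpleInterval
{
    public SimpleInterval(int start, int end)
    {
        Start = start;
        End = end;
    }

    public int Start { get; private set; }
    public int End { get; private set; }
}

public static class LinearIntervalScan
{
    // Returns every interval that contains the point.
    // Cost is O(n) per query, however few intervals actually match.
    public static List<SimpleInterval> FindAt(IEnumerable<SimpleInterval> intervals, int point)
    {
        return intervals
            .Where(x => x.Start <= point && point <= x.End)
            .ToList();
    }
}
```

This is perfectly adequate for a handful of intervals queried occasionally; the tree pays off when both the collection and the query volume grow.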

My C# implementation of the IntervalTree has the following features:

  • It uses generics.
  • It is mutable.
  • It is backed by a self-balancing BST (AVL).
  • It is XML and binary serializable.
  • It supports duplicate intervals.

Originally I wrote the tree so that the end point selector was simply a lambda. However, lambdas are not serializable, so I changed the selector to become an interface (and implemented my own XML serialization methods to handle the interface).

If your interval collections can afford to be immutable, then a better performing solution is to use a centered interval tree instead. Although the centered interval tree has the same average complexity as the augmented BST, the tree construction is generally faster.

Here is the code:

using System;
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Runtime.Serialization;
using System.Xml;
using System.Xml.Schema;
using System.Xml.Serialization;

namespace BrookNovak.Collections
{
	/// <summary>
	/// An interval tree that supports duplicate entries.
	/// </summary>
	/// <typeparam name="TInterval">The interval type</typeparam>
	/// <typeparam name="TPoint">The interval's start and end type</typeparam>
	/// <remarks>
	/// This interval tree is implemented as a balanced augmented AVL tree.
	/// Modifications are O(log n) typical case.
	/// Searches are O(log n) typical case.
	/// </remarks>
	[Serializable]
	public class IntervalTree<TInterval, TPoint> : ICollection<TInterval>, ICollection, ISerializable, IXmlSerializable 
		where TPoint : IComparable<TPoint>
	{
		private readonly object syncRoot;
		private IntervalNode root;
		private ulong modifications;
		private IIntervalSelector<TInterval, TPoint> intervalSelector;

		/// <summary>
		/// Default ctor required for XML serialization support
		/// </summary>
		private IntervalTree()
		{
			syncRoot = new object();
		}

		public IntervalTree(IEnumerable<TInterval> intervals, IIntervalSelector<TInterval, TPoint> intervalSelector) :
			this(intervalSelector)
		{
			AddRange(intervals);
		}

		public IntervalTree(IIntervalSelector<TInterval, TPoint> intervalSelector)
			: this()
		{
			if (intervalSelector == null)
				throw new ArgumentNullException("intervalSelector");
			this.intervalSelector = intervalSelector;
		}

		/// <summary>
		/// Returns the maximum end point in the entire collection.
		/// </summary>
		public TPoint MaxEndPoint
		{
			get
			{
				if(root == null)
					throw new InvalidOperationException("Cannot determine max end point for empty interval tree");
				return root.MaxEndPoint;
			}
		}

		#region Binary Serialization

		public IntervalTree(SerializationInfo info, StreamingContext context)
			: this()
		{
			// Reset the property value using the GetValue method.
			var intervals = (TInterval[])info.GetValue("intervals", typeof(TInterval[]));
			intervalSelector = (IIntervalSelector<TInterval, TPoint>)info.GetValue("selector", typeof(IIntervalSelector<TInterval, TPoint>));
			AddRange(intervals);
		}

		public void GetObjectData(SerializationInfo info, StreamingContext context)
		{
			var intervals = new TInterval[Count];
			CopyTo(intervals, 0);
			info.AddValue("intervals", intervals, typeof(TInterval[]));
			info.AddValue("selector", intervalSelector, typeof(IIntervalSelector<TInterval, TPoint>));
		}

		#endregion

		#region IXmlSerializable

		public void WriteXml(XmlWriter writer)
		{
			writer.WriteStartElement("Intervals");
			writer.WriteAttributeString("Count", Count.ToString(CultureInfo.InvariantCulture));
			var itemSerializer = new XmlSerializer(typeof(TInterval));
			foreach (var item in this)
			{
				itemSerializer.Serialize(writer, item);
			}
			writer.WriteEndElement();

			writer.WriteStartElement("Selector");
			var typeName = intervalSelector.GetType().AssemblyQualifiedName ?? intervalSelector.GetType().FullName;
			writer.WriteAttributeString("Type", typeName);
			var selectorSerializer = new XmlSerializer(intervalSelector.GetType());
			selectorSerializer.Serialize(writer, intervalSelector);
			writer.WriteEndElement();
		}

		public void ReadXml(XmlReader reader)
		{
			reader.MoveToContent();
			reader.ReadStartElement();

			reader.MoveToAttribute("Count");
			int count = int.Parse(reader.Value);
			reader.MoveToElement();

			if (count > 0 && reader.IsEmptyElement)
				throw new FormatException("Missing tree items");
			if (count == 0 && !reader.IsEmptyElement)
				throw new FormatException("Unexpected content in tree item Xml (expected empty content)");

			reader.ReadStartElement("Intervals");

			var items = new TInterval[count];

			if (count > 0)
			{
				var itemSerializer = new XmlSerializer(typeof(TInterval));

				for (int i = 0; i < count; i++)
				{
					items[i] = (TInterval)itemSerializer.Deserialize(reader);

				}
				reader.ReadEndElement(); // </intervals>
			}

			reader.MoveToAttribute("Type");
			string selectorTypeFullName = reader.Value;
			if (string.IsNullOrEmpty(selectorTypeFullName))
				throw new FormatException("Selector type name missing");
			reader.MoveToElement();

			reader.ReadStartElement("Selector");

			var selectorType = Type.GetType(selectorTypeFullName);
			if(selectorType == null)
				throw new XmlException(string.Format("Selector type {0} missing from loaded assemblies", selectorTypeFullName));
			var selectorSerializer = new XmlSerializer(selectorType);
			intervalSelector = (IIntervalSelector<TInterval, TPoint>)selectorSerializer.Deserialize(reader);

			reader.ReadEndElement(); // </selector>

			AddRange(items);
		}

		public XmlSchema GetSchema()
		{
			return null;
		}

		#endregion

		#region IEnumerable, IEnumerable<T>

		IEnumerator IEnumerable.GetEnumerator()
		{
			return new IntervalTreeEnumerator(this);
		}

		public IEnumerator<TInterval> GetEnumerator()
		{
			return new IntervalTreeEnumerator(this);
		}

		#endregion

		#region ICollection

		public bool IsSynchronized { get { return false; } }

		public Object SyncRoot { get { return syncRoot; } }

		public void CopyTo(
			Array array,
			int arrayIndex)
		{
			if (array == null)
				throw new ArgumentNullException("array");
			PerformCopy(arrayIndex, array.Length, (i, v) => array.SetValue(v, i));
		}

		#endregion

		#region ICollection<T>

		public int Count { get; private set; }

		public bool IsReadOnly
		{
			get
			{
				return false;
			}
		}

		public void CopyTo(
			TInterval[] array,
			int arrayIndex)
		{
			if (array == null)
				throw new ArgumentNullException("array");
			PerformCopy(arrayIndex, array.Length, (i, v) => array[i] = v);
		}

		/// <summary>
		/// Tests if an item is contained in the tree.
		/// </summary>
		/// <param name="item">The item to check</param>
		/// <returns>
		/// True iff the item exists in the collection. 
		/// </returns>
		/// <remarks>
		/// This method uses the collection’s objects’ Equals and CompareTo methods on item to determine whether item exists.
		/// </remarks>
		public bool Contains(TInterval item)
		{
			if (ReferenceEquals(item, null))
				throw new ArgumentNullException("item");

			return FindMatchingNodes(item).Any();
		}

		public void Clear()
		{
			SetRoot(null);
			Count = 0;
			modifications++;
		}

		public void Add(TInterval item)
		{
			if (ReferenceEquals(item, null))
				throw new ArgumentNullException("item");

			var newNode = new IntervalNode(item, Start(item), End(item));

			if (root == null)
			{
				SetRoot(newNode);
				Count = 1;
				modifications++;
				return;
			}

			IntervalNode node = root;
			while (true)
			{
				var startCmp = newNode.Start.CompareTo(node.Start);
				if (startCmp <= 0)
				{
					if (startCmp == 0 && ReferenceEquals(node.Data, newNode.Data))
						throw new InvalidOperationException("Cannot add the same item twice (object reference already exists in the tree)");

					if (node.Left == null)
					{
						node.Left = newNode;
						break;
					}
					node = node.Left;
				}
				else
				{
					if (node.Right == null)
					{
						node.Right = newNode;
						break;
					}
					node = node.Right;
				}
			}

			modifications++;
			Count++;

			// Restructure tree to be balanced
			node = newNode;
			while (node != null)
			{
				node.UpdateHeight();
				node.UpdateMaxEndPoint();
				Rebalance(node);
				node = node.Parent;
			}
		}

		/// <summary>
		/// Removes an item.
		/// </summary>
		/// <param name="item">The item to remove</param>
		/// <returns>True if an item was removed</returns>
		/// <remarks>
		/// This method uses the collection’s objects’ Equals and CompareTo methods on item to retrieve the existing item.
		/// If there are duplicates of the item, then object reference is used to remove.
		/// If <typeparamref name="TInterval"/> is not a reference type, then the first found equal interval will be removed.
		/// </remarks>
		public bool Remove(TInterval item)
		{
			if (ReferenceEquals(item, null))
				throw new ArgumentNullException("item");

			if (root == null)
				return false;

			var candidates = FindMatchingNodes(item).ToList();

			if (candidates.Count == 0)
				return false;

			IntervalNode toBeRemoved;
			if (candidates.Count == 1)
			{
				toBeRemoved = candidates[0];
			}
			else
			{
				toBeRemoved = candidates.SingleOrDefault(x => ReferenceEquals(x.Data, item)) ?? candidates[0];
			}

			var parent = toBeRemoved.Parent;
			var isLeftChild = toBeRemoved.IsLeftChild;

			if (toBeRemoved.Left == null && toBeRemoved.Right == null)
			{
				if (parent != null)
				{
					if (isLeftChild)
						parent.Left = null;
					else
						parent.Right = null;

					Rebalance(parent);
				}
				else
				{
					SetRoot(null);
				}
			}
			else if (toBeRemoved.Right == null)
			{
				if (parent != null)
				{
					if (isLeftChild)
						parent.Left = toBeRemoved.Left;
					else
						parent.Right = toBeRemoved.Left;

					Rebalance(parent);
				}
				else
				{
					SetRoot(toBeRemoved.Left);
				}
			}
			else if (toBeRemoved.Left == null)
			{
				if (parent != null)
				{
					if (isLeftChild)
						parent.Left = toBeRemoved.Right;
					else
						parent.Right = toBeRemoved.Right;

					Rebalance(parent);
				}
				else
				{
					SetRoot(toBeRemoved.Right);
				}
			}
			else
			{
				IntervalNode replacement, replacementParent, temp;

				if (toBeRemoved.Balance > 0)
				{
					if (toBeRemoved.Left.Right == null)
					{
						replacement = toBeRemoved.Left;
						replacement.Right = toBeRemoved.Right;
						temp = replacement;
					}
					else
					{
						replacement = toBeRemoved.Left.Right;
						while (replacement.Right != null)
						{
							replacement = replacement.Right;
						}
						replacementParent = replacement.Parent;
						replacementParent.Right = replacement.Left;

						temp = replacementParent;

						replacement.Left = toBeRemoved.Left;
						replacement.Right = toBeRemoved.Right;
					}
				}
				else
				{
					if (toBeRemoved.Right.Left == null)
					{
						replacement = toBeRemoved.Right;
						replacement.Left = toBeRemoved.Left;
						temp = replacement;
					}
					else
					{
						replacement = toBeRemoved.Right.Left;
						while (replacement.Left != null)
						{
							replacement = replacement.Left;
						}
						replacementParent = replacement.Parent;
						replacementParent.Left = replacement.Right;

						temp = replacementParent;

						replacement.Left = toBeRemoved.Left;
						replacement.Right = toBeRemoved.Right;
					}
				}

				if (parent != null)
				{
					if (isLeftChild)
						parent.Left = replacement;
					else
						parent.Right = replacement;
				}
				else
				{
					SetRoot(replacement);
				}

				Rebalance(temp);
			}

			toBeRemoved.Parent = null;
			Count--;
			modifications++;
			return true;
		}

		#endregion

		#region Public methods

		public void AddRange(IEnumerable<TInterval> intervals)
		{
			if (intervals == null)
				throw new ArgumentNullException("intervals");
			foreach (var interval in intervals)
			{
				Add(interval);
			}
		}
		
		public TInterval[] this[TPoint point]
		{
			get
			{
				return FindAt(point);
			}
		}

		public TInterval[] FindAt(TPoint point)
		{
			if (ReferenceEquals(point, null))
				throw new ArgumentNullException("point");

			var found = new List<IntervalNode>();
			PerformStabbingQuery(root, point, found);
			return found.Select(node => node.Data).ToArray();
		}

		public bool ContainsPoint(TPoint point)
		{
			return FindAt(point).Any();
		}

		public bool ContainsOverlappingInterval(TInterval item)
		{
			if (ReferenceEquals(item, null))
				throw new ArgumentNullException("item");

			return PerformStabbingQuery(root, item).Count > 0;
		}

		public TInterval[] FindOverlapping(TInterval item)
		{
			if (ReferenceEquals(item, null))
				throw new ArgumentNullException("item");

			return PerformStabbingQuery(root, item).Select(node => node.Data).ToArray();
		}

		#endregion

		#region Private methods

		private void PerformCopy(int arrayIndex, int arrayLength, Action<int, TInterval> setAtIndexDelegate)
		{
			if (arrayIndex < 0)
				throw new ArgumentOutOfRangeException("arrayIndex");
			int i = arrayIndex;
			IEnumerator<TInterval> enumerator = GetEnumerator();
			while (enumerator.MoveNext())
			{
				if (i >= arrayLength)
					throw new ArgumentOutOfRangeException("arrayIndex", "Not enough elements in array to copy content into");
				setAtIndexDelegate(i, enumerator.Current);
				i++;
			}
		}

		private IEnumerable<IntervalNode> FindMatchingNodes(TInterval interval)
		{
			return PerformStabbingQuery(root, interval).Where(node => node.Data.Equals(interval));
		}

		private void SetRoot(IntervalNode node)
		{
			root = node;
			if (root != null)
				root.Parent = null;
		}

		private TPoint Start(TInterval interval)
		{
			return intervalSelector.GetStart(interval);
		}

		private TPoint End(TInterval interval)
		{
			return intervalSelector.GetEnd(interval);
		}

		private bool DoesIntervalContain(TInterval interval, TPoint point)
		{
			return point.CompareTo(Start(interval)) >= 0
				&& point.CompareTo(End(interval)) <= 0;
		}

		private bool DoIntervalsOverlap(TInterval interval, TInterval other)
		{
			return Start(interval).CompareTo(End(other)) <= 0 &&
				End(interval).CompareTo(Start(other)) >= 0;
		}

		private void PerformStabbingQuery(IntervalNode node, TPoint point, List<IntervalNode> result)
		{
			if (node == null)
				return;

			if (point.CompareTo(node.MaxEndPoint) > 0)
				return;

			if (node.Left != null)
				PerformStabbingQuery(node.Left, point, result);

			if (DoesIntervalContain(node.Data, point))
				result.Add(node);

			if (point.CompareTo(node.Start) < 0)
				return;

			if (node.Right != null)
				PerformStabbingQuery(node.Right, point, result);
		}

		private List<IntervalNode> PerformStabbingQuery(IntervalNode node, TInterval interval)
		{
			var result = new List<IntervalNode>();
			PerformStabbingQuery(node, interval, result);
			return result;
		}

		private void PerformStabbingQuery(IntervalNode node, TInterval interval, List<IntervalNode> result)
		{
			if (node == null)
				return;

			if (Start(interval).CompareTo(node.MaxEndPoint) > 0)
				return;

			if (node.Left != null)
				PerformStabbingQuery(node.Left, interval, result);

			if (DoIntervalsOverlap(node.Data, interval))
				result.Add(node);

			if (End(interval).CompareTo(node.Start) < 0)
				return;

			if (node.Right != null)
				PerformStabbingQuery(node.Right, interval, result);
		}

		private void Rebalance(IntervalNode node)
		{
			if (node.Balance > 1)
			{
				if (node.Left.Balance < 0)
					RotateLeft(node.Left);
				RotateRight(node);
			}
			else if (node.Balance < -1)
			{
				if (node.Right.Balance > 0)
					RotateRight(node.Right);
				RotateLeft(node);
			}
		}

		private void RotateLeft(IntervalNode node)
		{
			var parent = node.Parent;
			var isNodeLeftChild = node.IsLeftChild;

			// Make node.Right the new root of this sub tree (instead of node)
			var pivotNode = node.Right;
			node.Right = pivotNode.Left;
			pivotNode.Left = node;

			if (parent != null)
			{
				if (isNodeLeftChild)
					parent.Left = pivotNode;
				else
					parent.Right = pivotNode;
			}
			else
			{
				SetRoot(pivotNode);
			}
		}

		private void RotateRight(IntervalNode node)
		{
			var parent = node.Parent;
			var isNodeLeftChild = node.IsLeftChild;

			// Make node.Left the new root of this sub tree (instead of node)
			var pivotNode = node.Left;
			node.Left = pivotNode.Right;
			pivotNode.Right = node;

			if (parent != null)
			{
				if (isNodeLeftChild)
					parent.Left = pivotNode;
				else
					parent.Right = pivotNode;
			}
			else
			{
				SetRoot(pivotNode);
			}
		}

		#endregion

		#region Inner classes

		[Serializable]
		private class IntervalNode
		{
			private IntervalNode left;
			private IntervalNode right;
			public IntervalNode Parent { get; set; }
			public TPoint Start { get; private set; }
			private TPoint End { get; set; }
			public TInterval Data { get; private set; }
			private int Height { get; set; }
			public TPoint MaxEndPoint { get; private set; }

			public IntervalNode(TInterval data, TPoint start, TPoint end)
			{
				if(start.CompareTo(end) > 0)
					throw new ArgumentOutOfRangeException("end", "The supplied interval has an invalid range, where start is greater than end");
				Data = data;
				Start = start;
				End = end;
				UpdateMaxEndPoint();
			}

			public IntervalNode Left
			{
				get
				{
					return left;
				}
				set
				{
					left = value;
					if (left != null)
						left.Parent = this;
					UpdateHeight();
					UpdateMaxEndPoint();
				}
			}

			public IntervalNode Right
			{
				get
				{
					return right;
				}
				set
				{
					right = value;
					if (right != null)
						right.Parent = this;
					UpdateHeight();
					UpdateMaxEndPoint();
				}
			}

			public int Balance
			{
				get
				{
					if (Left != null && Right != null)
						return Left.Height - Right.Height;
					if (Left != null)
						return Left.Height + 1;
					if (Right != null)
						return -(Right.Height + 1);
					return 0;
				}
			}

			public bool IsLeftChild
			{
				get
				{
					return Parent != null && Parent.Left == this;
				}
			}

			public void UpdateHeight()
			{
				if (Left != null && Right != null)
					Height = Math.Max(Left.Height, Right.Height) + 1;
				else if (Left != null)
					Height = Left.Height + 1;
				else if (Right != null)
					Height = Right.Height + 1;
				else
					Height = 0;
			}

			private static TPoint Max(TPoint comp1, TPoint comp2)
			{
				if (comp1.CompareTo(comp2) > 0)
					return comp1;
				return comp2;
			}

			public void UpdateMaxEndPoint()
			{
				TPoint max = End;
				if (Left != null)
					max = Max(max, Left.MaxEndPoint);
				if (Right != null)
					max = Max(max, Right.MaxEndPoint);
				MaxEndPoint = max;
			}

			public override string ToString()
			{
				return string.Format("[{0},{1}], maxEnd={2}", Start, End, MaxEndPoint);
			}
		}

		private class IntervalTreeEnumerator : IEnumerator<TInterval>
		{
			private readonly ulong modificationsAtCreation;
			private readonly IntervalTree<TInterval, TPoint> tree;
			private readonly IntervalNode startNode;
			private IntervalNode current;
			private bool hasVisitedCurrent;
			private bool hasVisitedRight;

			public IntervalTreeEnumerator(IntervalTree<TInterval, TPoint> tree)
			{
				this.tree = tree;
				modificationsAtCreation = tree.modifications;
				startNode = GetLeftMostDescendantOrSelf(tree.root);
				Reset();
			}

			public TInterval Current
			{
				get
				{
					if (current == null)
						throw new InvalidOperationException("Enumeration has finished.");

					if (ReferenceEquals(current, startNode) && !hasVisitedCurrent)
						throw new InvalidOperationException("Enumeration has not started.");

					return current.Data;
				}
			}

			object IEnumerator.Current
			{
				get
				{
					return Current;
				}
			}

			public void Reset()
			{
				if (modificationsAtCreation != tree.modifications)
					throw new InvalidOperationException("Collection was modified.");
				current = startNode;
				hasVisitedCurrent = false;
				hasVisitedRight = false;
			}

			public bool MoveNext()
			{
				if (modificationsAtCreation != tree.modifications)
					throw new InvalidOperationException("Collection was modified.");

				if (tree.root == null)
					return false;

				// Visit this node
				if (!hasVisitedCurrent)
				{
					hasVisitedCurrent = true;
					return true;
				}

				// Go right, visit the right's left most descendant (or the right node itself)
				if (!hasVisitedRight && current.Right != null)
				{
					current = current.Right;
					MoveToLeftMostDescendant();
					hasVisitedCurrent = true;
					hasVisitedRight = false;
					return true;
				}

				// Move upward
				do
				{
					var wasVisitingFromLeft = current.IsLeftChild;
					current = current.Parent;
					if (wasVisitingFromLeft)
					{
						hasVisitedCurrent = false;
						hasVisitedRight = false;
						return MoveNext();
					}
				} while (current != null);

				return false;
			}

			private void MoveToLeftMostDescendant()
			{
				current = GetLeftMostDescendantOrSelf(current);
			}

			private IntervalNode GetLeftMostDescendantOrSelf(IntervalNode node)
			{
				if (node == null)
					return null;
				while (node.Left != null)
				{
					node = node.Left;
				}
				return node;
			}

			public void Dispose()
			{
			}
		}

		#endregion
	}

	/// <summary>
	/// Selects interval start and end points for an object of type <typeparamref name="TInterval"/>.
	/// </summary>
	/// <typeparam name="TInterval">The type containing interval data</typeparam>
	/// <typeparam name="TPoint">The type of the interval start and end points</typeparam>
	/// <remarks>
	/// In order for the collection using these selectors to be XML serializable, your implementations of this interface must also be
	/// XML serializable (e.g. don't use delegates, and provide a default constructor).
	/// </remarks>
	public interface IIntervalSelector<in TInterval, out TPoint> where TPoint : IComparable<TPoint>
	{
		TPoint GetStart(TInterval item);
		TPoint GetEnd(TInterval item);
	}
}

Here is an example unit test:

[TestFixture]
public class ExampleTestFixture {
	
	[Test]
	public void FindAt_Overlapping ()
	{
		// ARRANGE
		var int1 = new TestInterval (20, 60);
		var int2 = new TestInterval (10, 50);
		var int3 = new TestInterval (40, 70);

		var intervalColl = new[] { 
			int1, int2, int3
		};
		var tree = new IntervalTree<TestInterval, int>(intervalColl, new TestIntervalSelector());

		// ACT
		var res1 = tree.FindAt (0);
		var res2 = tree.FindAt (10);
		var res3 = tree.FindAt (15);
		var res4 = tree.FindAt (20);
		var res5 = tree.FindAt (30);
		var res6 = tree.FindAt (40);
		var res7 = tree.FindAt (45);
		var res8 = tree.FindAt (50);
		var res9 = tree.FindAt (55);
		var res10 = tree.FindAt (60);
		var res11 = tree.FindAt (65);
		var res12 = tree.FindAt (70);
		var res13 = tree.FindAt (75);

		// ASSERT
		Assert.That (res1, Is.Empty);
		Assert.That (res2, Is.EquivalentTo(new[] { int2 }));
		Assert.That (res3, Is.EquivalentTo(new[] { int2 }));
		Assert.That (res4, Is.EquivalentTo(new[] { int1, int2 }));
		Assert.That (res5, Is.EquivalentTo(new[] { int1, int2 }));
		Assert.That (res6, Is.EquivalentTo(new[] { int1, int2, int3 }));
		Assert.That (res7, Is.EquivalentTo(new[] { int1, int2, int3 }));
		Assert.That (res8, Is.EquivalentTo(new[] { int1, int2, int3 }));
		Assert.That (res9, Is.EquivalentTo(new[] { int1, int3 }));
		Assert.That (res10, Is.EquivalentTo(new[] { int1, int3 }));
		Assert.That (res11, Is.EquivalentTo(new[] { int3 }));
		Assert.That (res12, Is.EquivalentTo(new[] { int3 }));
		Assert.That (res13, Is.Empty);
	}
}

[Serializable]
public class TestInterval 
{
	private TestInterval() {}

	public TestInterval(int low, int hi) 
	{
		if(low > hi)
			throw new ArgumentOutOfRangeException("low", "low is higher than hi");
		Low = low;
		Hi = hi;
	}

	public int Low { get; private set; }
	public int Hi { get; private set; }
	public string MutableData { get; set; }

	public override string ToString ()
	{
		return string.Format ("[Low={0}, Hi={1}, Data={2}]", Low, Hi, MutableData);
	}
}

[Serializable]
public class TestIntervalSelector : IIntervalSelector<TestInterval, int>
{
	public int GetStart (TestInterval item) 
	{
		return item.Low;
	}

	public int GetEnd (TestInterval item) 
	{
		return item.Hi;
	}
}

Note: I have rigorously unit tested this.

PS: Sorry, I probably should have uploaded the code and tests to an online source repo (like GitHub)… but I’m too lazy.

XML Serialization for interfaces in .NET

In your quest to find out how you can support XML serialization for types that contain interfaces, you may often find yourself coming to the same answer: you cannot serialize interfaces. That is true, but you can work around it, and I will present two methods.

Write your own XML serialization code

Refer to: http://social.msdn.microsoft.com/Forums/en-US/bac96f79-82cd-4fef-a748-2a85370a8510/xmlserialization-with-interfaces?forum=asmxandxml

This post provides a nice answer to your woes. But you should watch out for a gotcha: it records the concrete type using only its full name:

strType = m_Child.GetType().FullName;

Instead, you may want to use the assembly-qualified name, in case your dependent interface implementations could come from an external assembly:

string strType = m_Child.GetType().AssemblyQualifiedName 
	?? m_Child.GetType().FullName;
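
The difference matters when the type is looked up again during deserialization. As a rough sketch (TypeNameLookupSketch is a made-up name, and XmlDocument merely stands in for a type that lives outside your own assembly): Type.GetType with a bare full name only searches the calling assembly and the core library, whereas an assembly-qualified name tells it where to look.

```csharp
using System;
using System.Xml;

public static class TypeNameLookupSketch
{
    // Type.GetType(string) returns null, rather than throwing,
    // when it cannot find a type with the given name.
    public static Type Resolve(string typeName)
    {
        return Type.GetType(typeName);
    }

    // XmlDocument is defined outside both this assembly and the core library,
    // so a lookup by FullName alone is expected to come back null here.
    public static Type ResolveByFullName()
    {
        return Resolve(typeof(XmlDocument).FullName);
    }

    // The assembly-qualified name carries the assembly identity,
    // so the runtime can load the assembly and find the type.
    public static Type ResolveByAssemblyQualifiedName()
    {
        return Resolve(typeof(XmlDocument).AssemblyQualifiedName);
    }
}
```

The `?? FullName` fallback in the snippet above is still worth keeping, because AssemblyQualifiedName can be null for some types (for example open generic type parameters).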

Use generics with constraints

You can get away with not having to write your own serialization code by using generics. This isn’t a silver bullet as it will affect your class design, and may not make sense. Take this example:

public class MyClass 
{
	public IMyInterface MyProperty { get; set; }
}

You could rewrite this to become:

public class MyClass <T> where T : IMyInterface
{
	public T MyProperty { get; set; }
}

Of course this has several implications. For example, instantiating these types becomes more complicated and could lead to trickier wiring/construction problems. Another point to make is that the class itself doesn’t guarantee it is XML serializable: if the class is declared with a type argument (T) that is an interface (or a non-serializable class) then it won’t be XML serializable. But this should be nothing to be surprised about: the .NET Framework’s own XML serializable generic types also behave like this (e.g. List<T>).

Here is another example:

public class MyCollection
{
	public IList<IMyInterface> MyProperty { get; set; }
}

Generic XML serializable version (provided T is XML serializable):

public class MyCollection <T> where T : IMyInterface
{
	public IList<T> MyProperty { get; set; }
}
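
To see the generic approach working end to end, here is a small round-trip sketch. MyImplementation, its Name property and the GenericSerializationSketch helper are assumptions added for illustration; only MyClass<T> and IMyInterface come from the examples above.

```csharp
using System;
using System.IO;
using System.Xml.Serialization;

public interface IMyInterface
{
    string Name { get; set; }
}

// Hypothetical concrete implementation; it needs a public parameterless
// constructor and public read/write members to be XML serializable.
public class MyImplementation : IMyInterface
{
    public string Name { get; set; }
}

public class MyClass<T> where T : IMyInterface
{
    public T MyProperty { get; set; }
}

public static class GenericSerializationSketch
{
    public static string Serialize(MyClass<MyImplementation> value)
    {
        // The serializer is built for the closed generic type, so it knows
        // the concrete type of MyProperty without any custom code.
        var serializer = new XmlSerializer(typeof(MyClass<MyImplementation>));
        using (var writer = new StringWriter())
        {
            serializer.Serialize(writer, value);
            return writer.ToString();
        }
    }

    public static MyClass<MyImplementation> Deserialize(string xml)
    {
        var serializer = new XmlSerializer(typeof(MyClass<MyImplementation>));
        using (var reader = new StringReader(xml))
        {
            return (MyClass<MyImplementation>)serializer.Deserialize(reader);
        }
    }
}
```

Constructing the serializer for MyClass<IMyInterface> instead would fail at runtime, which is the caveat described above.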

Programming Praxis – Finding Digit Strings In Powers Of Two

Today’s boredom led me to solve another Programming Praxis problem:

Search every power of two below 2^10000 and return the index of the first power of two in which a target string appears. For instance, if the target is 42, the correct answer is 19 because 2^19 = 524288, in which the target 42 appears as the third and fourth digits, and no smaller power of two contains the string 42.

The naive solution is simple: keep doubling a number (starting at 1) until you find a power whose digits contain the target. Of course, once you reach 2^64 a primitive type will no longer suffice, so you need to come up with a way to store larger numbers.
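
As a baseline, the naive approach can be sketched with System.Numerics.BigInteger doing the big-number work. The class name, the string-typed target and the maxExponent parameter are choices made for this sketch, not part of the original problem statement.

```csharp
using System;
using System.Numerics;

public static class NaivePowerOfTwoSearch
{
    // Returns the smallest exponent e (0 <= e <= maxExponent) such that the
    // decimal digits of 2^e contain target, or -1 if there is none.
    public static int FindMinExponent(string target, int maxExponent = 10000)
    {
        BigInteger power = BigInteger.One; // 2^0
        for (int exponent = 0; exponent <= maxExponent; exponent++)
        {
            // Convert to decimal digits and do a plain substring search.
            if (power.ToString().IndexOf(target, StringComparison.Ordinal) >= 0)
                return exponent;
            power *= 2;
        }
        return -1;
    }
}
```

Every call starts again from 2^0 and re-renders each power as a string, which is the repeated work a cache can avoid.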

I improved on the naive approach by caching the digit sequences for each exponent. For the cache I used a type of suffix tree:

private class SuffixTreeNode {
	public short MinExponent;
	public SuffixTreeNode[] Children;
}

Unlike your traditional suffix tree, this one does not compress string sequences. It’s really just a trie, where the digit at each position in a sequence is stored implicitly as the index in Children (i.e. these arrays are of length 10). However, the data structure is populated and accessed like a suffix tree: each suffix of a number’s digit sequence is inserted into the tree. Each node in the tree is annotated with the smallest exponent in which its sequence can be found.

Here is the full code:

public static class DigitStringPowerTwoSearch
{
	private static SuffixTreeNode suffixRoot;
	private static short currentExponent;
	private static List<int> currentDigits;

	static DigitStringPowerTwoSearch () {
		FlushCache ();
	}

	internal static void FlushCache() {
		// Exposed as internal really for unit tests only
		currentExponent = 0;
		currentDigits = new List<int> { 1 };
		suffixRoot = new SuffixTreeNode ();
		AddDigitsToTree (currentDigits, 0, 0);
	}

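	public static int FindMinBaseTwoExponent(ulong target) {
		var targetDigits = ToDigitArray(target);
		while (true) {
			// Always check the cache first, so that previously seen sequences
			// still resolve once the cache has grown to the exponent limit.
			short exponent = FindMinExponentInTree(targetDigits);
			if (exponent >= 0)
				return exponent;
			// Stop growing the cache at the limit instead of adding an
			// exponent that would never be searched.
			if (currentExponent >= 10000)
				break;
			AddNextExponentToTree();
		}
		throw new ArgumentOutOfRangeException ("target", 
		                                       "target's digits do not exist in a base two number with exponent " +
		                                       "less or equal to 10,000");
	}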

	private static void AddNextExponentToTree() {
		DoubleDigits (currentDigits);
		currentExponent++;

		for (int i = 0; i < currentDigits.Count; i++) {
			AddDigitsToTree (currentDigits, i, currentExponent);
		}
	}

	private static void AddDigitsToTree(IList<int> digits, int startIndex, short exponent) {
		SuffixTreeNode current = suffixRoot;
		for (int i = startIndex; i < digits.Count; i++) {
			int digit = digits [i];
			if (current.Children == null) {
				current.Children = new SuffixTreeNode[10];
			}
			if (current.Children [digit] == null) {
				var newNode = new SuffixTreeNode { MinExponent = exponent };
				current.Children [digit] = newNode;
				current = newNode;
			} else {
				current = current.Children [digit];
				// Here we assume exponent is always the largest exponent,
				// so no need to check/update current.MinExponent
			}
		}
	}

	private static short FindMinExponentInTree(int[] targetDigits) {
		SuffixTreeNode current = suffixRoot;
		foreach(var digit in targetDigits) {
			if (current == null || current.Children == null)
				return -1;
			current = current.Children[digit];
		}
		if (current == null)
			return -1;
		return current.MinExponent;
	}

	private static int[] ToDigitArray(ulong n) {
		if (n == 0) 
			return new int[] { 0 };

		var digits = new List<int>();

		for (; n != 0; n /= 10)
			digits.Add((int)(n % 10));

		var arr = digits.ToArray();
		Array.Reverse(arr);
		return arr;
	}

	private static void DoubleDigits(List<int> digits) {
		bool carry = false;
		for (int i = digits.Count - 1; i >= 0; i--) {
			int d = digits [i] * 2;
			if (carry)
				d++;
			if (d >= 10) {
				d -= 10;
				carry = true;
			} else {
				carry = false;
			}
			digits [i] = d;
		}
		if (carry)
			digits.Insert (0, 1);
	}

	private class SuffixTreeNode {
		public short MinExponent;
		public SuffixTreeNode[] Children;
	}
}

When the function is invoked, if the target's digit sequence has been seen before, we get O(N) performance (where N = number of digits in target). Effectively the cache becomes a hash-trie.
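One way to sanity check the cached search is to compare its answers against a brute-force search. The following is only a minimal sketch of such a cross-check, assuming System.Numerics is available; the class and method names (NaivePowerTwoSearch, FindMinExponent) are made up for illustration and are not part of the code above.

```csharp
using System;
using System.Numerics;

public static class NaivePowerTwoSearch
{
	// Brute force (hypothetical helper): render each power of two as a decimal
	// string and look for the target's digits as a substring.
	public static int FindMinExponent(ulong target, int maxExponent = 10000) {
		string needle = target.ToString ();
		BigInteger power = BigInteger.One; // 2^0
		for (int exponent = 0; exponent <= maxExponent; exponent++) {
			if (power.ToString ().Contains (needle))
				return exponent;
			power *= 2;
		}
		return -1; // not found within the limit
	}
}
```

For example, NaivePowerTwoSearch.FindMinExponent(24) returns 10, because 1024 is the first power of two whose digits contain "24". Both implementations should agree for any target.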

A reflection on code reflection: where it helps, and where it hinders.

Today at work I broke some of my team project’s unit tests with a seemingly harmless C# code change: I simply changed a protected member variable into an auto-property. Unfortunately the change was bundled with other changes, which any reasonable coder would have assumed were to blame. But the culprit was the small, innocent (almost cosmetic) change, carried out at the click of a button with ReSharper. The tests blew up because an important piece of code (somewhere) used reflection to search the class for non-public instance member variables (those visible to the class type) and collect the members that implemented a specific interface. What is this!? Some silent protocol?? How fragile!
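To make the failure mode concrete, here is a small reconstruction of the kind of code involved. It is a sketch under assumptions, not the actual project code: every type name here (IHandler, DefaultHandler, BaseBefore, BaseAfter, FieldScanner) is hypothetical. It relies on the documented behaviour of Type.GetFields: with BindingFlags.NonPublic, protected fields of base classes are returned but private ones are not, and the compiler-generated backing field of an auto-property is private.

```csharp
using System;
using System.Linq;
using System.Reflection;

public interface IHandler { }
public class DefaultHandler : IHandler { }

// Before the refactor: a protected field, visible to derived types.
public class BaseBefore { protected IHandler handler = new DefaultHandler (); }
public class DerivedBefore : BaseBefore { }

// After the refactor: an auto-property. Its backing field is private to BaseAfter.
public class BaseAfter { protected IHandler Handler { get; set; } = new DefaultHandler (); }
public class DerivedAfter : BaseAfter { }

public static class FieldScanner
{
	// Collects the names of non-public instance fields whose type implements IHandler.
	public static string[] FindHandlerFields (Type type) {
		return type.GetFields (BindingFlags.Instance | BindingFlags.NonPublic)
			.Where (f => typeof(IHandler).IsAssignableFrom (f.FieldType))
			.Select (f => f.Name)
			.ToArray ();
	}

	public static void Main () {
		// The protected field is found when scanning the derived type...
		Console.WriteLine (FindHandlerFields (typeof(DerivedBefore)).Length); // prints 1
		// ...but the private backing field of the auto-property is not.
		Console.WriteLine (FindHandlerFields (typeof(DerivedAfter)).Length);  // prints 0
	}
}
```

Nothing in the compiler or the refactoring tool can warn about this: the scan silently returns fewer members.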


Reflection is a powerful tool that has blessed us with many awesome, easy-to-use APIs. But clearly it is not suitable for every problem. So when should we use reflection? And when should we avoid it? Here are a few common pros and cons that tend to crop up around the topic of reflection (this is by no means a comprehensive list!):

Good reflection 🙂

  • Dependency injection frameworks.
    Reflection has given us killer IoC tools like Windsor and Unity to solve our dependency problems.
    Clearly reflection is a key enabler of the technology, as dependency discovery and instantiation are both achieved via analysis of binary metadata.
  • Plugin frameworks.
    Plugin frameworks commonly use reflection to dynamically load 3rd party plugins; discovering and loading additional libraries at run time would not be so easy without it.

Bad reflection 😦

  • Refactoring tools and code analysis tools are incompatible.
    The opening example of this post shows that refactoring tools cannot account for what reflection does, which makes your code brittle. It’s much better to be explicit with your code design; avoid establishing implicit protocols in your code base which your reflection code requires in order to work correctly. Note that static code analysis tools such as refactoring tools, or features like discovering method usage (e.g. with ReSharper), are rendered useless by code that uses reflection. This is a dangerous place to be.
  • Adds a super-generic layer of indirection.
    Indirection is a double-edged sword: it can improve the design (and yield a number of benefits), but at the cost of added complexity. The problem with reflection is that it adds a higher degree of indirection than non-reflective code, because it hides static detail such as class names, method names, and property names. So heavy use of reflection makes a program nearly impossible to follow in a static code walkthrough. It can also be very difficult to debug.
  • Run-time errors instead of compile-time errors.
    This argument can be used against all sorts of mechanisms (such as dynamic type-checking features), but it is a good point to make. If you have the option of a design that doesn’t require reflection, at least there is a chance your compiler will complain if a code change has broken something. A design using reflection is subject to run-time errors, which in the worst case may not be detected until a release cycle (or in production!).
  • Invocation via reflection is much slower.
    Generally the performance hit from reflection is negligible, but in sections of code where reflection is used heavily, performance will degrade. Reflective invocation is much slower because the binary's metadata must be inspected at run time to resolve the member (rather than the call being resolved at compile time).
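As a small illustration of the run-time versus compile-time point, consider the sketch below. The Greeter class and the method name "SayHello" are invented for the example. A misspelled or stale member name in a direct call is a compile error, whereas the same mistake in a reflective lookup compiles fine and only shows up when the code runs.

```csharp
using System;
using System.Reflection;

public class Greeter
{
	public string Greet (string name) { return "Hello, " + name; }
}

public static class ReflectionFailureDemo
{
	public static void Main () {
		var greeter = new Greeter ();

		// Direct call: the compiler verifies that Greet exists and takes a string.
		Console.WriteLine (greeter.Greet ("world"));

		// Reflective call: the name is just a string, resolved at run time.
		MethodInfo method = typeof(Greeter).GetMethod ("Greet");
		Console.WriteLine (method.Invoke (greeter, new object[] { "world" }));

		// A stale name (say the method was renamed) still compiles.
		// GetMethod simply returns null, and the problem surfaces only at run time.
		MethodInfo stale = typeof(Greeter).GetMethod ("SayHello");
		Console.WriteLine (stale == null ? "SayHello not found" : "SayHello found");
	}
}
```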

Conclusion

Avoid reflection.

If you think you need to solve a problem with reflection, rework the design (don’t be lazy!). Also, don’t use reflection to get at protected data (e.g. non-public members): violating standard language conventions will get you into all sorts of trouble further down the line. Only use reflection where it is absolutely the only way to meet your needs. An acceptable place to use reflection is where there would otherwise be no way to implement a solution, at least not without enforcing a difficult or cumbersome protocol. So be wary of the drawbacks of reflection before you get crazy with it, and always strive for a solid design!

Non-blocking Writer Collection (C# Example)

A question was posed to me recently:

If you had a thread that produced messages, and pushed those messages to one or more consumer threads: how would you write the code to ensure that the producer thread executes as fast as possible? (i.e. no blocking on the producer thread).

An interesting problem to solve! Before looking at the suggested solution below, try and have a go at coming up with your own solution.

Let me present one way to go about this problem. The general idea is to use two message queues: a write queue, and a read queue. The writer (producer) thread always writes to the write queue. The reader thread always reads from the read queue. When the reader thread exhausts the read queue, the reader switches the queues: the write queue becomes the read queue, and vice versa. If the reader finds that the new read queue (after the switch) is also empty, then it waits until the writer thread writes to the write queue. The reader thread then switches the queues again, and reads from the new read queue to extract the most recent item.

Note: I’ve deliberately chosen the name “Collection” for this ADT, since the implementation does not guarantee ordering. It is sort of FIFO, but not strictly (otherwise I would have called it a NonBlockingWriterQueue!). Here is a C# example:

 


using System;
using System.Collections.Generic;
using System.Threading;

public class NonBlockingWriterCollection<T> : IDisposable {

	private Queue<T> leftQueue = new Queue<T>();
	private Queue<T> rightQueue = new Queue<T>();

	private volatile bool isWriting = false;
	private volatile bool writeToLeft = true;
	private object readerLocker = new object ();
	private EventWaitHandle readerWaitHandle = new AutoResetEvent (false);

	private volatile bool disposed = false;

	~NonBlockingWriterCollection()
	{
		Dispose(false);
	}

	#region IDisposable
	public void Dispose()
	{
		Dispose(true);
		GC.SuppressFinalize(this);
	}

	protected virtual void Dispose(bool disposing)
	{
		if(!disposed) {
			disposed = true;
			// Signal the reader to wake up, since closing/disposing the event handle doesn't raise
			// object-disposed exceptions for threads already in a waiting state.
			readerWaitHandle.Set (); 
			if(disposing) {
				readerWaitHandle.Close (); // Can cause in-progress writes to throw a disposed exception.
			}

		}
	}
	#endregion

	/// <summary>
	/// Enqueues the specified item.
	/// </summary>
	/// <param name="item">Item to enqueue</param>
	/// <exception cref="ObjectDisposedException">If the collection has been disposed prior to executing this method</exception>
	/// <remarks>
	/// Only a single writer thread can enqueue.
	/// This operation is non-blocking.
	/// </remarks>
	public void Write(T item) {
		if (disposed)
			throw new ObjectDisposedException ("NonBlockingWriterCollection");

		try {
			isWriting = true;

			// Full fence: a volatile write followed by a volatile read may still be
			// reordered, so make sure isWriting is visible before writeToLeft is read.
			Thread.MemoryBarrier ();

			// Queue an item on the write queue
			if(writeToLeft)
				leftQueue.Enqueue(item);
			else
				rightQueue.Enqueue(item);

			// Signal reader thread that an item has been added
			readerWaitHandle.Set();

		} finally {
			isWriting = false;
		}

	}

	/// <summary>
	/// Dequeues an item.
	/// </summary>
	/// <exception cref="ObjectDisposedException">If the collection has been disposed prior to or during execution of this method</exception>
	/// <remarks>
	/// Multiple reader threads can attempt to dequeue an item.
	/// This operation is blocking (until an item has been enqueued, or the collection has been disposed).
	/// </remarks>
	public T Read() {
		lock (readerLocker) {
			if (disposed)
				throw new ObjectDisposedException ("NonBlockingWriterCollection");

			// Reset the wait handle, at this point we are searching for an item on either queue
			readerWaitHandle.Reset ();

			// Dequeue an item from the queue that is not being written to
			var readQueue = writeToLeft ? rightQueue : leftQueue;
			if (readQueue.Count > 0)
				return readQueue.Dequeue ();

			while (!disposed) {
				// The read queue has been exhausted. Swap read/write queue
				writeToLeft = !writeToLeft;

				// Full fence: publish the swap before checking isWriting (pairs with
				// the fence in Write, so one side always observes the other).
				Thread.MemoryBarrier ();

				// At this point, the writer thread could be writing to either queue, 
				// wait for the write to finish using a spin lock
				while (isWriting) {
				} // busy waiting

				// Try read again from the read queue
				readQueue = writeToLeft ? rightQueue : leftQueue;
				if (readQueue.Count > 0)
					return readQueue.Dequeue ();

				// Both queues have been exhausted, we need to wait for the writer to
				// do something. Block the reader until the writer has signalled.
				// Note: it may have added an item during the read... so this may
				// not block, and continue the read right away
				readerWaitHandle.WaitOne ();
			}

			throw new ObjectDisposedException ("NonBlockingWriterCollection");
		}
	}
}
}

If you know you are only going to have a single reader, then you can simply remove the reader mutex to improve performance.

Note the use of volatile members. Volatile reads/writes are required for these primitive members, otherwise all sorts of chaos could ensue should the compiler or CPU choose to re-arrange or cache the read/write instructions.

And here is a little test harness:

private const int ReaderCount = 5;
private NonBlockingWriterCollection<int> nbQueue;

public void Run ()
{
	Console.WriteLine ("Starting sim...");
	Thread[] readerThreads;
	Thread writerThread;
	using(nbQueue = new NonBlockingWriterCollection<int>()) {

		readerThreads = new Thread[ReaderCount];
		for(int i = 0; i < ReaderCount; i++) {
			readerThreads [i] = new Thread (RunReader);
			readerThreads [i].Start (i); // box i
		}

		writerThread = new Thread (RunWriter);
		writerThread.Start ();

		Thread.Sleep (1000 * 5);
	}

	Console.WriteLine ("Waiting for sim threads to finish...");
	writerThread.Join ();
	foreach (var rt in readerThreads)
		rt.Join ();

	Console.WriteLine ("Finished sim.");
}

private void RunReader(object threadNum) {
	var rand = new Random (((int)threadNum) * 6109425);
	string threadId = "ReaderThread " + (char)('A' + (int)threadNum); // unbox i
	try {
		while (true) {
			var item = nbQueue.Read (); // blocking
			Console.WriteLine (threadId + " read " + item);
			Thread.Sleep(rand.Next() % 10); // Do some "work" to process the data
		}
	} catch(ObjectDisposedException) {}
}

private void RunWriter() {
	string threadId = "WriterThread";
	var rand = new Random ();
	try {
		while (true) {
			int batchSize = rand.Next (10); // pick the batch size once, rather than re-rolling it on every loop test
			for(int i = 0; i < batchSize; i++) {
				var item = rand.Next ();
				Console.WriteLine (threadId + " writing " + item);
				nbQueue.Write (item); // non-blocking
			}
			Thread.Sleep(10 + rand.Next() % 100); // Do some "work" to produce more data
		}
	} catch(ObjectDisposedException) {}
}

Programming Praxis – Egyptian Fractions, C# solution.

This post presents a C# solution to the Egyptian fractions problem as described in http://programmingpraxis.com/2013/06/04/egyptian-fractions.

An Egyptian fraction was written as a sum of unit fractions, meaning the numerator is always 1; further, no two denominators can be the same. An easy way to create an Egyptian fraction is to repeatedly take the largest unit fraction that will fit, subtract to find what remains, and repeat until the remainder is a unit fraction. For instance, 7 divided by 15 is less than 1/2 but more than 1/3, so the first unit fraction is 1/3 and the first remainder is 2/15. Then 2/15 is less than 1/7 but more than 1/8, so the second unit fraction is 1/8 and the second remainder is 1/120. That’s in unit form, so we are finished: 7 ÷ 15 = 1/3 + 1/8 + 1/120. There are other algorithms for finding Egyptian fractions, but there is no algorithm that guarantees a maximum number of terms or a minimum largest denominator; for instance, the greedy algorithm leads to 5 ÷ 121 = 1/25 + 1/757 + 1/763309 + 1/873960180913 + 1/1527612795642093418846225, but a simpler rendering of the same number is 1/33 + 1/121 + 1/363.

Your task is to write a program that calculates the ratio of two numbers as an Egyptian fraction…

As presented in the original post, you can use a greedy algorithm. That's not fun! Let’s try to improve on it, i.e. return fewer unit fractions. Guaranteeing the smallest number of terms is hard, since the space of candidate denominators is unbounded (genetic algorithm? food for thought).
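For reference, the greedy baseline described in the problem statement can be sketched as follows. This is a minimal sketch, not the original post's code: the names GreedyEgyptianFractions and GetDenominators are made up, and BigInteger is assumed because the greedy denominators for 5 ÷ 121 overflow 64-bit integers.

```csharp
using System;
using System.Collections.Generic;
using System.Numerics;

public static class GreedyEgyptianFractions
{
	// Returns the denominators of the greedy expansion of numerator/denominator.
	public static List<BigInteger> GetDenominators (BigInteger numerator, BigInteger denominator) {
		if (numerator <= 0 || numerator >= denominator)
			throw new ArgumentOutOfRangeException ("numerator");

		var result = new List<BigInteger> ();
		while (numerator > 0) {
			// Largest unit fraction that fits: d = ceil(denominator / numerator)
			BigInteger d = (denominator + numerator - 1) / numerator;
			result.Add (d);

			// numerator/denominator - 1/d = (numerator*d - denominator) / (denominator*d)
			numerator = numerator * d - denominator;
			denominator = denominator * d;

			// Reduce the remainder to lowest terms
			if (!numerator.IsZero) {
				BigInteger gcd = BigInteger.GreatestCommonDivisor (numerator, denominator);
				numerator /= gcd;
				denominator /= gcd;
			}
		}
		return result;
	}
}
```

Calling GetDenominators(7, 15) yields 3, 8, 120, matching the worked example in the problem statement.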

So what approach can we take to improve on the greedy solution? I began by writing out a few iterations of calculations for 5 ÷ 121: from 1/25 up to 1/33 (part of the shorter solution presented in the problem definition). I noticed that when choosing 1/33, the remaining fraction (that we are pulling apart) can be simplified, whereas every unit fraction leading up to 1/33 leaves a remainder that cannot be simplified further! If you think about it, simplifying keeps the size of the denominator down, and keeping denominators small helps yield a smaller number of terms. This is because when we are dealing with very small numbers (large denominators), we approach the final target at a slower rate than we would with larger numbers. Simple huh?
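The observation is easy to reproduce. The throwaway sketch below (the RemainderDemo name is made up) prints the remainder 5/121 − 1/d for d from 25 to 33, along with the greatest common divisor of the remainder's numerator and denominator.

```csharp
using System;

public static class RemainderDemo
{
	static int Gcd (int a, int b) { return b == 0 ? a : Gcd (b, a % b); }

	public static void Main () {
		const int numerator = 5, denominator = 121;
		for (int d = 25; d <= 33; d++) {
			// 5/121 - 1/d = (5*d - 121) / (121*d)
			int remNum = numerator * d - denominator;
			int remDen = denominator * d;
			Console.WriteLine ("1/{0}: remainder {1}/{2}, gcd {3}", d, remNum, remDen, Gcd (remNum, remDen));
		}
	}
}
```

Every line reports a gcd of 1 except the last: for 1/33 the remainder is 44/3993 with a gcd of 11, which reduces to 4/363.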

So how can you simplify a fraction? It can be done by calculating the gcd (greatest common divisor) of the numerator and the denominator, then dividing both by the gcd. If the gcd is 1, then the fraction cannot be simplified. Hmmm… so maybe we can accept the next unit fraction (for subtracting) only if the resulting remainder can be simplified. Using this informal idea as the basis of our algorithm, we get the following solution:

using System;
using System.Collections.Generic;

public class EgyptionFractions
{
	public static List<int[]> GetFractions(int numerator, int denominator) {
		if (numerator >= denominator)
			throw new ArgumentOutOfRangeException ("denominator");
		if (numerator <= 0)
			throw new ArgumentOutOfRangeException ("numerator");

		var fractions = new List<int[]> ();
		int subDenominator = 2;

		do {
			// First find the next unit fraction that is small enough to subtract
			int leftNumerator = numerator * subDenominator;
			while (leftNumerator < denominator) { // Note: rightNumerator == denominator
				subDenominator++;
				leftNumerator += numerator;
			}

			// Now we have a valid unit fraction to subtract, let's continue
			// searching for the next unit fraction that yields a remainder that 
			// can be simplified (to keep the denominators small).
			while (true) {
				int remainingNumerator = leftNumerator - denominator;
				if(remainingNumerator == 0) {
					// The fractions are the same
					numerator = 0;
					fractions.Add (new [] {1, subDenominator});
					break;
				}
				int remainingDenominator = denominator * subDenominator;
				int gcd = GCD (remainingNumerator, remainingDenominator);
				if (gcd > 1 || remainingNumerator == 1) {
					// The resultant fraction can be simplified using this denominator
					numerator = remainingNumerator / gcd;
					denominator = remainingDenominator / gcd;
					fractions.Add (new [] {1, subDenominator});

					// Finished?
					if(numerator == 1) 
						fractions.Add (new [] {1, denominator});
					break;
				}
				subDenominator++;
				leftNumerator += numerator; // i.e. additive version of subDenominator * numerator;
			}

			subDenominator++;
		} while (numerator > 1);

		return fractions;
	}

	private static int GCD(int n1, int n2) {
		if (n2 == 0)
			return n1;
		return GCD (n2, n1 % n2);
	}
}

If you pass in 5, 121, the result will be:
1/33, 1/91, 1/33033
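
A quick way to double check a result like this is to add the unit fractions back together. The helper below is a sketch only (UnitFractionCheck and Sum are hypothetical names, and BigInteger is assumed to avoid overflow while accumulating).

```csharp
using System;
using System.Numerics;

public static class UnitFractionCheck
{
	// Sums 1/d over the given denominators and returns the total in lowest terms
	// as (numerator, denominator).
	public static Tuple<BigInteger, BigInteger> Sum (params int[] denominators) {
		BigInteger num = 0, den = 1;
		foreach (int d in denominators) {
			// num/den + 1/d = (num*d + den) / (den*d)
			num = num * d + den;
			den = den * d;
			BigInteger gcd = BigInteger.GreatestCommonDivisor (num, den);
			num /= gcd;
			den /= gcd;
		}
		return Tuple.Create (num, den);
	}
}
```

UnitFractionCheck.Sum(33, 91, 33033) returns (5, 121), confirming that the three terms add up to the original fraction.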