Saturday, April 4, 2015

Partitioning an enumerable into fixed-size chunks

Recently at work, I was querying a database for items, with a list of ids in hand. Since MS SQL Server allows only a limited number of parameters per query, I had to split my ids into fixed-size chunks and execute several queries. The simple task of splitting an IEnumerable into fixed-size chunks turned out to be not so straightforward, so I decided it was worth writing a short blog post about it. My initial solution was along the lines of:
     public static IEnumerable<IEnumerable<T>> Partition<T>(this IEnumerable<T> source, int size)  
     {  
       // check arguments ...  
       var result = new List<T>(size);  
       foreach (var item in source)  
       {  
         result.Add(item);  
         if (result.Count == size)  
         {  
           yield return result;  
           result = new List<T>(size);  
         }  
       }  
       if (result.Count > 0)  
         yield return result;  
     }  
Now this seems quite straightforward, but I didn't like it too much, because it eagerly fills a result list before returning it. I was sure there was a lazier solution, so I googled it, and was pretty surprised that on SO most of the answers were either as bad as my initial solution or much worse.

My disappointment quickly turned into a decision: I would write an implementation that satisfies what I expect of this method and its result. I'll walk you through my assumptions and the way I think is best to tackle this problem.
I want to simply write a unit test for each assumption I am making (I use NUnit), then provide a simple solution that passes all the tests, and then invite you to make it more elegant/compact/readable/professional, or whatever else you think could be done to the code. So let's start:

1) I always like to start with my corner cases, as they are the easiest to implement, and once the functionality is implemented it is boring to add them. Somehow I don't mind doing them up front. First assumption: for a null enumerable I want an appropriate exception.
     [Test]  
     public void Partition_ForNullEnumerable_ThrowsArgumentNullException()  
     {  
       IEnumerable<long> source = null;  
       var ex = Assert.Throws<ArgumentNullException>(() => source.Partition(5).ToList(), "null source is not allowed");  
       Assert.AreEqual("source", ex.ParamName, "Error should reflect input field");  
     }  
Note that I need to call ToList(): Partition is an iterator method, so its body (including the argument checks) does not run until the result is enumerated.
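To illustrate why the test has to enumerate, here is a minimal sketch of deferred validation in an iterator method. PassThrough and the two helper methods are hypothetical names made up for this example; they are not part of the Partition code.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class DeferredValidationDemo
{
    // Hypothetical iterator: the null check is part of the iterator body,
    // so it only runs once somebody starts enumerating the result.
    public static IEnumerable<T> PassThrough<T>(IEnumerable<T> source)
    {
        if (source == null)
            throw new ArgumentNullException("source");
        foreach (var item in source)
            yield return item;
    }

    // Only calls the method; the returned sequence is never enumerated.
    public static bool ThrowsOnCall()
    {
        try { PassThrough<int>(null); return false; }
        catch (ArgumentNullException) { return true; }
    }

    // ToList() enumerates the sequence, which finally runs the null check.
    public static bool ThrowsOnEnumeration()
    {
        var sequence = PassThrough<int>(null);
        try { sequence.ToList(); return false; }
        catch (ArgumentNullException) { return true; }
    }
}
```

ThrowsOnCall() returns false and ThrowsOnEnumeration() returns true, which is why the assertion wraps the call in ToList().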

2) For partition size of less than 1 I want to have an appropriate error:
     [Test]  
     public void Partition_ForPartitionSizeBelow1_ThrowsArgumentException()  
     {  
       var source = new List<long>();  
       var ex = Assert.Throws<ArgumentException>(() => source.Partition(0).ToList(), "partition cannot be below 1");  
       StringAssert.Contains("0", ex.Message, "Message should reflect error");  
       StringAssert.Contains("size", ex.Message, "Message should reflect error");  
     }  
So far so good.

3) So now that we have the boring stuff behind us, let's write a series of tests that reflect how I expect my method to behave. Since most of my tests will test laziness, I will build an enumerable where I can easily check which items have been "produced". I will use a simple context class:
     class TestContext  
     {  
       /// <summary>  
       /// the collection of all our letters  
       /// </summary>  
       public const string SourceLetters = "0123456789ABCDEF";  
       public IEnumerable<char> Source { get; private set; }  
       public string ProducedItems { get { return builder.ToString(); } }  
       public TestContext()  
       {  
         builder = new StringBuilder();  
         Source = SourceLetters  
           .ToCharArray()  
           .Select(x =>  
           {  
             builder.Append(x);  
             return x;  
           });  
       }  
       private StringBuilder builder;  
     }  
     private TestContext GetContext()  
     {  
       return new TestContext();  
     }  
This way I can easily assert on the ProducedItems.

4) Next assumption: just calling my method should not produce anything.
     [Test]  
     public void Partition_CallingMethod_DoesNotEnumerateSource()  
     {  
       var ctx = GetContext();  
       var result = ctx.Source.Partition(5);  
       Assert.AreEqual(string.Empty, ctx.ProducedItems, "No items should have been produced before enumeration");  
     }  
Since I think all of the solutions using yield return would pass this, let's quickly move on.

5) Check for correctness: Partitioning to the whole set will give me the correct items.
     [Test]  
     public void Partition_ToSizeOfSet_GivesCorrectResult()  
     {  
       var ctx = GetContext();  
       var result = ctx.Source.Partition(TestContext.SourceLetters.Length);  
       Assert.AreEqual(1, result.Count(), "we should get only one partition");  
       CollectionAssert.AreEqual(TestContext.SourceLetters.ToCharArray(), result.First().ToArray(), "The first partition should have all the items");  
     }  

6) What about when the size is even bigger?
     [Test]  
     public void Partition_ToSizeBiggerThanSet_GivesCorrectResult()  
     {  
       var ctx = GetContext();  
       var result = ctx.Source.Partition(TestContext.SourceLetters.Length + 5);  
       Assert.AreEqual(1, result.Count(), "we should get only one partition");  
       CollectionAssert.AreEqual(TestContext.SourceLetters, result.First(), "The first partition should have all the items");  
     }  

7) OK, now partition into equal parts that divide the total count exactly:
     [Test]  
     public void Partition_To4Parts_GivesCorrectResult()  
     {  
       var ctx = GetContext();  
       var result = ctx.Source.Partition(4);  
       Assert.AreEqual(4, result.Count(), "splitting 16 letters into sizes of 4 should result in 4 partitions");  
       CollectionAssert.AreEqual(TestContext.SourceLetters.Substring(0, 4).ToCharArray(), result.First().ToArray(), "The first partition has invalid result");  
       CollectionAssert.AreEqual(TestContext.SourceLetters.Substring(4, 4).ToCharArray(), result.Skip(1).First().ToArray(), "The second partition has invalid result");  
       CollectionAssert.AreEqual(TestContext.SourceLetters.Substring(8, 4).ToCharArray(), result.Skip(2).First().ToArray(), "The third partition has invalid result");  
       CollectionAssert.AreEqual(TestContext.SourceLetters.Substring(12, 4).ToCharArray(), result.Skip(3).First().ToArray(), "The fourth partition has invalid result");  
     }  

8) Sweet, and the last one: partition into parts that don't divide the total count exactly:
     [Test]  
     public void Partition_ToPartsOf5_GivesCorrectResult()  
     {  
       var ctx = GetContext();  
       var result = ctx.Source.Partition(5);  
       Assert.AreEqual(4, result.Count(), "splitting 16 letters into sizes of 5 should result in 4 partitions");  
       CollectionAssert.AreEqual(TestContext.SourceLetters.Substring(0, 5), result.First(), "The first partition has invalid result");  
       CollectionAssert.AreEqual(TestContext.SourceLetters.Substring(5, 5), result.Skip(1).First(), "The second partition has invalid result");  
       CollectionAssert.AreEqual(TestContext.SourceLetters.Substring(10, 5), result.Skip(2).First(), "The third partition has invalid result");  
       CollectionAssert.AreEqual(TestContext.SourceLetters.Substring(15, 1), result.Skip(3).First(), "The fourth partition has invalid result");  
     }  
So far so good. If a solution doesn't pass all these tests, then it is definitely wrong.

9) Now I want to make sure that I get correct results even if I enumerate my results in a different order than expected.
     [Test]  
     public void Partition_EnumeartingInReverseOrder_GivesCorrectResult()  
     {  
       var ctx = GetContext();  
       using (var enumerator = ctx.Source.Partition(8).GetEnumerator())  
       {  
         enumerator.MoveNext(); // go to first element  
         var result1 = enumerator.Current;  
         enumerator.MoveNext(); // go to second element  
         var result2 = enumerator.Current;  
         CollectionAssert.AreEqual(TestContext.SourceLetters.Substring(8, 8).ToCharArray(), result2.ToArray(), "The second partition has invalid result");  
         CollectionAssert.AreEqual(TestContext.SourceLetters.Substring(0, 8).ToCharArray(), result1.ToArray(), "The first partition has invalid result");  
       }  
     }  
The funny thing is, I have seen accepted SO answers that would not pass this test.

Now it is time to start testing the laziness of our solution. This is where my initial solution starts to fail, and where I will have to start writing a new solution.

10) So if I get the first partition, I want to produce only the first item from the original enumerable. I need to produce it, otherwise the system cannot know whether there is at least one partition or none...
     [Test]  
     public void Partition_GettingFirstPartition_OnlyProduces1Item()  
     {  
       var ctx = GetContext();  
       using (var enumerator = ctx.Source.Partition(8).GetEnumerator())  
       {  
         enumerator.MoveNext(); // go to first element  
         var result1 = enumerator.Current;  
         Assert.AreEqual("0", ctx.ProducedItems, "Only the first item should have been produced");  
       }  
     }  
Here we will already fail. Our initial solution will produce all 8 items in the partition, before returning it. And this is exactly my concern. If the "production" of these items is expensive, and I am looking only for the first item that fulfills my conditions, I do not want to produce all items in the partition. Makes no sense.
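To make the concern concrete, here is a small sketch that counts how many source items the eager version produces by the time the first partition is handed out. EagerPartition is my initial solution from above copied into a hypothetical demo class, and ItemsProducedForFirstPartition is a helper name made up for this illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class EagerPartitionDemo
{
    // The initial, eager solution: each chunk is filled completely before it is yielded.
    public static IEnumerable<IEnumerable<T>> EagerPartition<T>(IEnumerable<T> source, int size)
    {
        var result = new List<T>(size);
        foreach (var item in source)
        {
            result.Add(item);
            if (result.Count == size)
            {
                yield return result;
                result = new List<T>(size);
            }
        }
        if (result.Count > 0)
            yield return result;
    }

    // Counts the source items produced when only the first partition is requested.
    public static int ItemsProducedForFirstPartition(int size)
    {
        int produced = 0;
        // A lazy source of 16 items that records every item it produces.
        var source = Enumerable.Range(0, 16).Select(x => { produced++; return x; });
        using (var enumerator = EagerPartition(source, size).GetEnumerator())
        {
            enumerator.MoveNext(); // go to first partition, nothing else
        }
        return produced;
    }
}
```

For a partition size of 8, ItemsProducedForFirstPartition(8) returns 8, although nobody has looked at a single item of that partition yet.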

11) Now I want to test that when I start enumerating the first partition, getting its first item will not reproduce that item and will not produce any other ones.
     [Test]  
     public void Partition_GettingFirstElementOfFirstPartition_OnlyProduces1Item()  
     {  
       var ctx = GetContext();  
       using (var enumerator = ctx.Source.Partition(8).GetEnumerator())  
       {  
         enumerator.MoveNext(); // go to first element  
         var result1 = enumerator.Current;  
         using(var enumerator1 = result1.GetEnumerator()){  
           enumerator1.MoveNext(); // go to first element  
           var result1_1 = enumerator1.Current;  
         }  
         Assert.AreEqual("0", ctx.ProducedItems, "Only the first item should have been produced");  
       }  
     }  

12) I want to make sure that when I enumerate the first partition, I am producing the items in a lazy manner, so that getting the second item from the first partition produces only 2 items:
     [Test]  
     public void Partition_GettingFirst2ElementsOfFirstPartition_OnlyProduces2Items()  
     {  
       var ctx = GetContext();  
       using (var enumerator = ctx.Source.Partition(8).GetEnumerator())  
       {  
         enumerator.MoveNext(); // go to first element  
         var result1 = enumerator.Current;  
         using(var enumerator1 = result1.GetEnumerator()){  
           enumerator1.MoveNext(); // go to first element  
           enumerator1.MoveNext(); // go to second element  
           var result1_1 = enumerator1.Current;  
         }  
         Assert.AreEqual("01", ctx.ProducedItems, "Only the first 2 items should have been produced");  
       }  
     }  

13) And the same for the third:
     [Test]  
     public void Partition_GettingFirst3ElementsOfFirstPartition_OnlyProduces3Items()  
     {  
       var ctx = GetContext();  
       using (var enumerator = ctx.Source.Partition(8).GetEnumerator())  
       {  
         enumerator.MoveNext(); // go to first element  
         var result1 = enumerator.Current;  
         using(var enumerator1 = result1.GetEnumerator()){  
           enumerator1.MoveNext(); // go to first element  
           enumerator1.MoveNext(); // go to second element  
           enumerator1.MoveNext(); // go to third element  
           var result1_1 = enumerator1.Current;  
         }  
         Assert.AreEqual("012", ctx.ProducedItems, "Only the first 3 items should have been produced");  
       }  
     }  

14) Now I want to make sure that enumerating through the first partition will not produce any items from the second partition, as that is not necessary:
     [Test]  
     public void Partition_EnumeratingAllFirstPartition_OnlyProducesItemsFromThere()  
     {  
       var ctx = GetContext();  
       using (var enumerator = ctx.Source.Partition(3).GetEnumerator())  
       {  
         enumerator.MoveNext(); // go to first element  
         var result1 = enumerator.Current.ToList();  
         Assert.AreEqual("012", ctx.ProducedItems, "Only the first 3 items should have been produced, as that is our partition size");  
       }  
     }  

15) So far so good. I understand that to get to the second partition, all items of the first partition have to be produced, since an enumerable only allows going forward one item at a time. But I want to make sure that the second partition is not enumerated to the end the way the first one was:
     [Test]  
     public void Partition_GettingSecondPartition_OnlyProducesItemsToTheFirstItemOfPartition2()  
     {  
       var ctx = GetContext();  
       using (var enumerator = ctx.Source.Partition(3).GetEnumerator())  
       {  
         enumerator.MoveNext(); // go to first partition  
         enumerator.MoveNext(); // go to second partition  
         var result2 = enumerator.Current;  
         Assert.AreEqual("0123", ctx.ProducedItems, "Only the first 4 items should have been produced, as the second partition starts at item 4");  
       }  
     }  

16) Now I want to have the same lazy production for partition 2 as for partition 1:
     [Test]  
     public void Partition_EnumerationgSecondPartition_ProducesItemsLazily()  
     {  
       var ctx = GetContext();  
       using (var enumerator = ctx.Source.Partition(4).GetEnumerator())  
       {  
         enumerator.MoveNext(); // go to first partition  
         enumerator.MoveNext(); // go to second partition  
         using (var enumerator2 = enumerator.Current.GetEnumerator())  
         {  
           enumerator2.MoveNext(); // go to first item  
           enumerator2.MoveNext(); // go to second item  
           enumerator2.MoveNext(); // go to third item  
           var result2_3 = enumerator2.Current;  
         }  
         Assert.AreEqual("0123456", ctx.ProducedItems, "Only the first 7 items should have been produced");  
       }  
     }  

17) And now I want to make sure that even if I enumerate the second partition first, and then the first, the items of the first partition are not reproduced (meaning my original enumerable is not restarted). I think in most cases that is the expected behavior, although I could also imagine that when the items are cheap to produce but require a lot of memory, the expected behavior would be to reproduce the items. But in our use case here, we assume that producing is more expensive than keeping the items around.
     [Test]  
     public void Partition_EnumerationgFirstPartitionAfterSecond_DoesNotReproduceItem()  
     {  
       var ctx = GetContext();  
       using (var enumerator = ctx.Source.Partition(4).GetEnumerator())  
       {  
         enumerator.MoveNext(); // go to first partition  
         var result1 = enumerator.Current;  
         enumerator.MoveNext(); // go to second partition  
         var result2 = enumerator.Current.ToList(); // enumerate 2nd partition  
         var list1 = result1.ToList(); // enumerate first partition  
         Assert.AreEqual("01234567", ctx.ProducedItems, "The first partition items should not be reproduced");  
       }  
     }  

18) Now we are already quite far along in our requirements. I would also like to be able to re-enumerate any partition without enumerating my original source again, so no reproduction of items.
     [Test]  
     public void Partition_MultiplePartitionEnumeartion_DoesNotReproduceItem()  
     {  
       var ctx = GetContext();  
       using (var enumerator = ctx.Source.Partition(4).GetEnumerator())  
       {  
         enumerator.MoveNext(); // go to first partition  
         var result1 = enumerator.Current;  
         var list1 = result1.ToList(); // enumerate first partition  
         var list2 = result1.ToList(); // enumerate first partition again  
         Assert.AreEqual("0123", ctx.ProducedItems, "The first partition items should not be reproduced");  
       }  
     }  

I think this concludes my assumptions about the partitioning. What about the solution? I would invite you to take these tests and come up with a method that passes all of them.
When you are done with it, you can come back and compare it to my solution. If you find yours more compact, easier to understand, or simply more elegant, I would be very interested to see your code.

Mine looks like:
   public static class Extensions  
   {  
     /// <summary>  
     /// cached enumeration with push possibilities  
     /// </summary>  
     private class CachedEnumeration<T> : IEnumerable<T>  
     {  
       /// <summary>  
       /// enumerator for the cachedEnumeration class  
       /// </summary>  
       class CachedEnumerator : IEnumerator<T>  
       {  
         private readonly CachedEnumeration<T> m_source;  
         private int m_index;  
         public CachedEnumerator(CachedEnumeration<T> source)  
         {  
           m_source = source;  
           // start at index -1, since an enumerator needs to start with MoveNext before calling current  
           m_index = -1;  
         }  
         public T Current  
         {  
           get { return m_source.m_items[m_index]; }  
         }  
         public void Dispose()  
         {  
         }  
         object System.Collections.IEnumerator.Current  
         {  
           get { return Current; }  
         }  
         public bool MoveNext()  
         {  
           // if we have cached items, just increase our index  
           if (m_source.m_items.Count > m_index + 1)  
           {  
             m_index++;  
             return true;  
           }  
           else  
           {  
             var result = m_source.FetchOne();  
             if (result)  
               m_index++;  
             return result;  
           }  
         }  
         public void Reset()  
         {  
           m_index = -1;  
         }  
       }  
       /// <summary>  
       /// list containing all the items  
       /// </summary>  
       private readonly List<T> m_items;  
       /// <summary>  
       /// callback how to fetch an item  
       /// </summary>  
       private readonly Func<Tuple<bool, T>> m_fetchMethod;  
       private readonly int m_targetSize;  
       public CachedEnumeration(int size, T firstItem, Func<Tuple<bool, T>> fetchMethod)  
       {  
         m_items = new List<T>(size);  
         m_items.Add(firstItem);  
         m_fetchMethod = fetchMethod;  
         m_targetSize = size;  
       }  
       public IEnumerator<T> GetEnumerator()  
       {  
         return new CachedEnumerator(this);  
       }  
       System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()  
       {  
         return GetEnumerator();  
       }  
       private bool FetchOne()  
       {  
         if (IsFull)  
           return false;  
         var result = m_fetchMethod();  
         if (result.Item1)  
           m_items.Add(result.Item2);  
         return result.Item1;  
       }  
       /// <summary>  
       /// fetches all items to the cached enumerable  
       /// </summary>  
       public void FetchAll()  
       {  
         while (FetchOne()) { }  
       }  
       /// <summary>  
       /// tells whether the enumeration is already full  
       /// </summary>  
       public bool IsFull { get { return m_targetSize == m_items.Count; } }  
     }  
     /// <summary>  
     /// partitions the <paramref name="source"/> to parts of size <paramref name="size"/>  
     /// </summary>  
     public static IEnumerable<IEnumerable<T>> Partition<T>(this IEnumerable<T> source, int size)  
     {  
       if (source == null)  
         throw new ArgumentNullException("source");  
       if (size < 1)  
         throw new ArgumentException(string.Format("The specified size ({0}) is invalid, it needs to be at least 1.", size), "size");  
       var enumerator = source.GetEnumerator();  
       while (enumerator.MoveNext())  
       {  
         var lastResult = new CachedEnumeration<T>(size, enumerator.Current, () => Tuple.Create(enumerator.MoveNext(), enumerator.Current));  
         yield return lastResult;  
         lastResult.FetchAll();  
       }  
     }  
   }  
What do you think?
I don't much like that the enumerator is not in a using statement, but if it were, then a call to .First() on the result of this method would dispose the enumerator, and the first partition would not be able to produce any more items. Maybe it is possible to wrap the enumerator in another disposable class that remembers that the enumerator is still in use somewhere else, and disposes it only once it is not needed anymore at all, instead of just letting it hang around waiting for the GC to find it...
I'll explore it in another blog entry :)
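The disposal concern can be reproduced in isolation. The following is a minimal sketch, not part of the Partition code: Numbers is a hypothetical iterator whose finally block stands in for the cleanup a using statement would perform, and CleanupRan is a flag made up to observe it.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class IteratorDisposeDemo
{
    // Set by the finally block, which plays the role of a using statement's Dispose call.
    public static bool CleanupRan;

    public static IEnumerable<int> Numbers()
    {
        try
        {
            yield return 1;
            yield return 2;
        }
        finally
        {
            // Runs when the sequence ends OR when the consumer disposes the enumerator early.
            CleanupRan = true;
        }
    }

    public static bool FirstTriggersCleanup()
    {
        CleanupRan = false;
        // First() takes one item and then disposes the enumerator it created.
        var first = Numbers().First();
        return first == 1 && CleanupRan;
    }
}
```

FirstTriggersCleanup() returns true: the cleanup runs as soon as First() is done, even though the sequence was not enumerated to the end. With the source enumerator in a using statement inside Partition, the same thing would happen to it while the first partition still wants to fetch items.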

Well, that was it for today. You can download the source code here.
And as always, I'm open for feedback!!!!

Thanks for reading.

Wednesday, June 4, 2014

Retrieving Generic Method Definitions

I had encountered this problem at least a couple of times, and never really found an elegant solution for it until a couple of days ago. The problem is very simple: I need to retrieve, for example, the
public static IQueryable<IGrouping<TKey, TSource>> GroupBy<TSource, TKey>(this IQueryable<TSource> source, Expression<Func<TSource, TKey>> keySelector)  
method definition. Since the method has quite a lot of generic arguments, which are themselves complex types, it is quite a pain to retrieve it with a reflection search. When I googled, I found solutions like:
           var methodDefinition = typeof(Queryable).GetMethods()  
               .Where(x => x.Name == "GroupBy")  
               .Select(x => new { M = x, P = x.GetParameters() })  
               .Where(x => x.P.Length == 2  
                 && x.P[0].ParameterType.IsGenericType  
                 && x.P[0].ParameterType.GetGenericTypeDefinition() == typeof(IQueryable<>)  
                 && x.P[1].ParameterType.IsGenericType  
                 && x.P[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>))  
               .Select(x => new { x.M, A = x.P[1].ParameterType.GetGenericArguments() })  
               .Where(x => x.A[0].IsGenericType  
                 && x.A[0].GetGenericTypeDefinition() == typeof(Func<,>))  
               .Select(x => new { x.M, A = x.A[0].GetGenericArguments() })  
               .Where(x => x.A[0].IsGenericParameter  
                 && x.A[1].IsGenericParameter)  
               .Select(x => x.M)  
               .SingleOrDefault();  
I mean, really? Are you kidding me? There has to be an easier way. This is waaaaay too complicated.
Well, I have good news: none of it is needed. Why would we do all this reflection heavy lifting, if Microsoft has made a tool that already does this? Yep, you guessed it, it is the compiler.

Elegant solution

So what am I talking about? Very simple.
When you define a lambda expression, the compiler actually builds up the expression for you. A lambda expression is the definition of the code inside, and if you build an expression with the desired method, well, it will be in the expression for you to retrieve.
So how does the other code translate?
           Expression<Func<IQueryable<string>, IQueryable<IGrouping<bool, string>>>> fakeExp = q => q.GroupBy(x => x.StartsWith(string.Empty));  
           var methodDefinition = ((MethodCallExpression)fakeExp.Body).Method.GetGenericMethodDefinition();  
Since we use only the method definition from the expression, the concrete types obviously do not matter. We get the generic method definition, and from it we get the desired version by calling MakeGenericMethod. You can cache the generic method definition statically, and then create a simple method that gets you the desired closed version, like:
       private static MethodInfo _queryableGroupByMethod;  
       public static MethodInfo QueryableGroupByMethod  
       {  
         get  
         {  
           if (_queryableGroupByMethod == null)  
           {  
             Expression<Func<IQueryable<string>, IQueryable<IGrouping<bool, string>>>> fakeExp = q => q.GroupBy(x => x.StartsWith(string.Empty));  
             _queryableGroupByMethod = ((MethodCallExpression)fakeExp.Body).Method.GetGenericMethodDefinition();  
           }  
           return _queryableGroupByMethod;  
         }  
       }  
       public static MethodInfo MakeQueryableGroupByMethod(Type source, Type destination)  
       {  
         return QueryableGroupByMethod.MakeGenericMethod(source, destination);  
       }  
That's it. I guess the pattern is clear, and you can use it for any other method. No strings, no complicated selects and wheres. Just pure, compiler-checked goodies :)
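If you use the pattern often, it can be wrapped in a small helper. This is only a sketch: MethodDefinitions, Of and MakeGroupBy are hypothetical names of my own, and the helper assumes the lambda body is a single method call.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public static class MethodDefinitions
{
    // Extracts the method called in the lambda body; for generic methods
    // the open generic method definition is returned.
    public static MethodInfo Of<T, TResult>(Expression<Func<T, TResult>> call)
    {
        var methodCall = call.Body as MethodCallExpression;
        if (methodCall == null)
            throw new ArgumentException("The lambda body must be a single method call.", "call");
        var method = methodCall.Method;
        return method.IsGenericMethod ? method.GetGenericMethodDefinition() : method;
    }

    // Example: Queryable.GroupBy<TSource, TKey>(source, keySelector), closed over the given types.
    // The string/int types in the lambda are placeholders; only the method definition is used.
    public static MethodInfo MakeGroupBy(Type source, Type key)
    {
        var definition = Of((IQueryable<string> q) => q.GroupBy(x => x.Length));
        return definition.MakeGenericMethod(source, key);
    }
}
```

For example, MakeGroupBy(typeof(DateTime), typeof(int)) gives the GroupBy overload with TSource = DateTime and TKey = int.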

Tuesday, February 19, 2013

Memorizing method results with PostSharp (part 3)

This is the 3rd episode of a series of posts on how to statically cache method results.
In the first post we discussed what PostSharp is, and how it can be used to create an attribute that caches the results.
In the second post we removed hard references to the object instance and to the arguments.
In this third post, we will focus on the remaining limitations, namely generics and null parameters.
Let's jump right into it then.

Generic methods

Well, to make sure that generic methods do not mess up the picture, let's create a generic method inside our TestClass:
        public class TestClass
        {
            // previous code
            [...]

            public Dictionary<Type, int> NrOfGenericExecutions = new Dictionary<Type,int>();
            [CacheResultAttribute.CacheResult(CacheLocations.Static)]
            public int DoubleTheNumber<T>(int i)
            {
                if (!NrOfGenericExecutions.ContainsKey(typeof(T)))
                {
                    NrOfGenericExecutions.Add(typeof(T), 0);
                }
                NrOfGenericExecutions[typeof(T)] = NrOfGenericExecutions[typeof(T)] + 1;
                return i * 2;
            }
        }
and then write a test for it:
        [TestMethod]
        public void TestGenericMethods()
        {
            var t = new TestClass();
            Assert.AreEqual(0, t.NrOfGenericExecutions.Count);
            t.DoubleTheNumber<int>(2);
            Assert.AreEqual(1, t.NrOfGenericExecutions.Count);
            Assert.AreEqual(1, t.NrOfGenericExecutions[typeof(int)]);
            t.DoubleTheNumber<int>(3);
            Assert.AreEqual(1, t.NrOfGenericExecutions.Count);
            Assert.AreEqual(2, t.NrOfGenericExecutions[typeof(int)]);
            t.DoubleTheNumber<int>(2);
            Assert.AreEqual(1, t.NrOfGenericExecutions.Count);
            Assert.AreEqual(2, t.NrOfGenericExecutions[typeof(int)]);

            t.DoubleTheNumber<string>(2);
            Assert.AreEqual(2, t.NrOfGenericExecutions.Count);
            Assert.AreEqual(2, t.NrOfGenericExecutions[typeof(int)]);
            Assert.AreEqual(1, t.NrOfGenericExecutions[typeof(string)]);
            t.DoubleTheNumber<string>(3);
            Assert.AreEqual(2, t.NrOfGenericExecutions.Count);
            Assert.AreEqual(2, t.NrOfGenericExecutions[typeof(int)]);
            Assert.AreEqual(2, t.NrOfGenericExecutions[typeof(string)]);
            t.DoubleTheNumber<string>(2);
            Assert.AreEqual(2, t.NrOfGenericExecutions.Count);
            Assert.AreEqual(2, t.NrOfGenericExecutions[typeof(int)]);
            Assert.AreEqual(2, t.NrOfGenericExecutions[typeof(string)]);
        }
Now running this test will fail. Why is that? Well, since we have an instance-scoped attribute, I guess the attribute instance is the same for all constructed versions of DoubleTheNumber<T>. So all we need to do here is save the results in a dictionary that has the method as its key. The changes are minimal. We change our methodCache private field to
        private IDictionary<MethodBase, object> methodCache = new Dictionary<MethodBase, object>();
Of course we need to initialize this field for new instances, so the implementation of IInstanceScopedAspect changes to include the initialization of this field:
        void IInstanceScopedAspect.RuntimeInitializeInstance()
        {
            this.methodCache = new Dictionary<MethodBase, object>();
        }
Then we change our OnEntry and OnExit methods to use the dictionary value for the exact method, instead of just a simple object. Everything else stays the same:
        public override void OnEntry(PostSharp.Aspects.MethodExecutionArgs args)
        {

            if (!this.methodCache.ContainsKey(args.Method))
            {
                this.methodCache.Add(args.Method, null);
            }

            //remaining initialization
            [...]
            object resultDictionary = this.methodCache[args.Method];
            [...]

        }
And voila, our tests pass. Great. Generic arguments nailed.
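Why does keying by the method help? With plain reflection, each constructed version of a generic method is a distinct MethodInfo, so each gets its own dictionary entry. The sketch below shows only that reflection behavior and does not use PostSharp; it assumes that args.Method is the constructed method for the type argument in use, which is what the passing test suggests. GenericMethodKeyDemo and DistinctKeys are made-up names.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

public static class GenericMethodKeyDemo
{
    public static int DoubleTheNumber<T>(int i) { return i * 2; }

    // Counts the distinct cache keys produced by different constructions of the same generic method.
    public static int DistinctKeys()
    {
        var definition = typeof(GenericMethodKeyDemo).GetMethod("DoubleTheNumber");
        var cache = new Dictionary<MethodBase, object>();
        cache[definition.MakeGenericMethod(typeof(int))] = null;
        cache[definition.MakeGenericMethod(typeof(string))] = null;
        cache[definition.MakeGenericMethod(typeof(int))] = null; // same key as the first entry
        return cache.Count;
    }
}
```

DistinctKeys() returns 2: one entry for DoubleTheNumber<int> and one for DoubleTheNumber<string>.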
Now what about null value arguments?

Null values as arguments

So do we support null values? Let's see in a test. We are using our TestClass2 from the 2nd post here.
        [TestMethod]
        public void TestNullArguments()
        {
            var t = new TestClass2();
            var ta = new TestClass();
            Assert.AreEqual(0, t.NrOfExecutions);
            t.DoSomething(ta);
            Assert.AreEqual(1, t.NrOfExecutions);
            t.DoSomething(ta);
            Assert.AreEqual(1, t.NrOfExecutions);
            t.DoSomething(null);
            Assert.AreEqual(2, t.NrOfExecutions);
            t.DoSomething(null);
            Assert.AreEqual(2, t.NrOfExecutions);
            t.DoSomething(ta);
            Assert.AreEqual(2, t.NrOfExecutions);
            
        }
Well, this test fails miserably. The ConditionalWeakTable will throw an exception if you try to add a null key. Actually, if you think about it, a null key kind of misses the point: the structure exists so that when the key object is garbage collected, the whole entry disappears. Now since null cannot be garbage collected...
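The failure is easy to reproduce outside of the aspect. A minimal sketch, with NullKeyDemo and AddingNullKeyThrows as made-up names:

```csharp
using System;
using System.Runtime.CompilerServices;

public static class NullKeyDemo
{
    // Returns true if adding a null key to a ConditionalWeakTable is rejected.
    public static bool AddingNullKeyThrows()
    {
        var table = new ConditionalWeakTable<object, object>();
        try
        {
            table.Add(null, new object());
            return false;
        }
        catch (ArgumentNullException)
        {
            return true;
        }
    }
}
```

AddingNullKeyThrows() returns true, so the value cached for a null argument has to live somewhere else.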
So it seems that we will need an additional way of storing values for nulls. But is that so difficult? We already single out the arguments that can be null (value types cannot be null), so let's just store the values for arguments like that in a slightly extended structure. We create a wrapper around ConditionalWeakTable. I don't want to implement any special methods, so I will just expose the members directly, but you could even try to implement IDictionary, or just pull some of the methods from the internal table onto the object itself; it's up to you. I just did:
    public class WeakNullableKeyValueCollection<TKey, TValue>
        where TKey: class
        where TValue : class
    {
        /// <summary>
        /// the table for the keys
        /// </summary>
        public ConditionalWeakTable<TKey, TValue> ConditionalWeakTable { get; private set; }

        private TValue _ValueForNull;
        /// <summary>
        /// the value for null
        /// </summary>
        public TValue ValueForNull { get { return _ValueForNull; } set { _ValueForNull = value; IsNullValueSet = true; } }

        /// <summary>
        /// whether the null value is set
        /// </summary>
        public bool IsNullValueSet { get; private set; }


        public WeakNullableKeyValueCollection()
        {
            ConditionalWeakTable = new ConditionalWeakTable<TKey, TValue>();
        }
    }
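If you would rather pull the lookup methods up into the wrapper, as mentioned above, it could look roughly like this. This is only a sketch of that alternative, not the code used in the rest of the post; the class name NullableKeyTable and its two methods are my own invention here.

```csharp
using System;
using System.Runtime.CompilerServices;

// Hypothetical variant of the wrapper that hides the null check
// behind TryGetValue/Add instead of exposing the members directly.
public class NullableKeyTable<TKey, TValue>
    where TKey : class
    where TValue : class
{
    private readonly ConditionalWeakTable<TKey, TValue> table =
        new ConditionalWeakTable<TKey, TValue>();
    private TValue valueForNull;
    private bool isNullValueSet;

    public bool TryGetValue(TKey key, out TValue value)
    {
        if (key == null)
        {
            // null never reaches the ConditionalWeakTable
            value = valueForNull;
            return isNullValueSet;
        }
        return table.TryGetValue(key, out value);
    }

    public void Add(TKey key, TValue value)
    {
        if (key == null)
        {
            valueForNull = value;
            isNullValueSet = true;
        }
        else
        {
            table.Add(key, value);
        }
    }
}

public static class NullableKeyTableDemo
{
    public static void Main()
    {
        var cache = new NullableKeyTable<object, object>();
        var key = new object();
        cache.Add(key, "for key");
        cache.Add(null, "for null");

        object found;
        Console.WriteLine(cache.TryGetValue(key, out found));   // True
        Console.WriteLine(cache.TryGetValue(null, out found));  // True
        Console.WriteLine(found);                               // for null
    }
}
```

With a wrapper like this the calling code would not need the arg == null branches, at the price of a slightly bigger class.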
Using this class instead of simply a ConditionalWeakTable class should do the trick:
        public override void OnEntry(PostSharp.Aspects.MethodExecutionArgs args)
        {
            //existing code
            [...]
                        if (!types[i].IsValueType)
                        {
                            if (arg == null)
                            {
                                if ((resultDictionary as WeakNullableKeyValueCollection<object, object>).IsNullValueSet)
                                {
                                    resultDictionary = (resultDictionary as WeakNullableKeyValueCollection<object, object>).ValueForNull;
                                }
                                else
                                {
                                    isMissingKey = true;
                                    break;
                                }
                            }
                            else
                            {
                                if ((resultDictionary as WeakNullableKeyValueCollection<object, object>).ConditionalWeakTable.TryGetValue(arg, out result))
                                {
                                    resultDictionary = result;
                                }
                                else
                                {
                                    isMissingKey = true;
                                    break;
                                }

                            }
                        }
            //existing code
            [...]
        }

        public override void OnExit(PostSharp.Aspects.MethodExecutionArgs args)
        {
            //existing code
            [...]
                    if (!types[i].IsValueType)
                    {
                        if (getter() == null)
                        {
                            setter(new WeakNullableKeyValueCollection<object, object>());
                        }
                        var currenttable = getter();

                        getter = () =>
                        {
                            if (arg == null)
                            {
                                return (currenttable as WeakNullableKeyValueCollection<object, object>).ValueForNull;
                            }
                            else
                            {
                                object temp;
                                (currenttable as WeakNullableKeyValueCollection<object, object>).ConditionalWeakTable.TryGetValue(arg, out temp);
                                return temp;
                            }
                        };
                        setter = o =>
                        {
                            if (arg == null)
                            {
                                (currenttable as WeakNullableKeyValueCollection<object, object>).ValueForNull = o;
                            }
                            else
                            {
                                (currenttable as WeakNullableKeyValueCollection<object, object>).ConditionalWeakTable.Add(arg, o);
                            }
                        };

                    }
            //existing code
            [...]
        }
Now run the unit tests, and they all pass. Cool, null support added.

Generic classes

Now what about generic classes? Let's write a small test class,
        public class TestClass<T>
        {
            public static int StaticNrOfExecutions = 0;
            [CacheResultAttribute.CacheResult(CacheLocations.Static)]
            public static int StaticDoubleTheNumber(int i)
            {
                StaticNrOfExecutions++;
                return i * 2;
            }

            public int CallStaticDoubleTheNumber(int i)
            {
                return StaticDoubleTheNumber(i);
            }


            public int NrOfExecutions = 0;
            [CacheResultAttribute.CacheResult(CacheLocations.Static)]
            public int DoubleTheNumber(int i)
            {
                NrOfExecutions++;
                return i * 2;
            }
        }
and a test for it:
        [TestMethod]
        public void TestGenericClasses()
        {
            var ti = new TestClass<int>();
            var ts = new TestClass<string>();
            Assert.AreEqual(0, ti.NrOfExecutions);
            Assert.AreEqual(0, TestClass<int>.StaticNrOfExecutions);
            Assert.AreEqual(0, ts.NrOfExecutions);
            Assert.AreEqual(0, TestClass<string>.StaticNrOfExecutions);

            ti.DoubleTheNumber(1);
            Assert.AreEqual(1, ti.NrOfExecutions);
            Assert.AreEqual(0, TestClass<int>.StaticNrOfExecutions);
            Assert.AreEqual(0, ts.NrOfExecutions);
            Assert.AreEqual(0, TestClass<string>.StaticNrOfExecutions);

            ti.CallStaticDoubleTheNumber(1);
            Assert.AreEqual(1, ti.NrOfExecutions);
            Assert.AreEqual(1, TestClass<int>.StaticNrOfExecutions);
            Assert.AreEqual(0, ts.NrOfExecutions);
            Assert.AreEqual(0, TestClass<string>.StaticNrOfExecutions);

            TestClass<int>.StaticDoubleTheNumber(1);
            Assert.AreEqual(1, ti.NrOfExecutions);
            Assert.AreEqual(1, TestClass<int>.StaticNrOfExecutions);
            Assert.AreEqual(0, ts.NrOfExecutions);
            Assert.AreEqual(0, TestClass<string>.StaticNrOfExecutions);

            TestClass<string>.StaticDoubleTheNumber(1);
            Assert.AreEqual(1, ti.NrOfExecutions);
            Assert.AreEqual(1, TestClass<int>.StaticNrOfExecutions);
            Assert.AreEqual(0, ts.NrOfExecutions);
            Assert.AreEqual(1, TestClass<string>.StaticNrOfExecutions);

            ts.CallStaticDoubleTheNumber(1);
            Assert.AreEqual(1, ti.NrOfExecutions);
            Assert.AreEqual(1, TestClass<int>.StaticNrOfExecutions);
            Assert.AreEqual(0, ts.NrOfExecutions);
            Assert.AreEqual(1, TestClass<string>.StaticNrOfExecutions);

            ts.DoubleTheNumber(1);
            Assert.AreEqual(1, ti.NrOfExecutions);
            Assert.AreEqual(1, TestClass<int>.StaticNrOfExecutions);
            Assert.AreEqual(1, ts.NrOfExecutions);
            Assert.AreEqual(1, TestClass<string>.StaticNrOfExecutions);
        }
Well, the test passes, so it seems that somewhere along the way we already implemented this feature.
The reason we don't get collisions is that the MethodBase we use as a key for caching method return values is different for two generic classes with different generic parameters.
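You can check this with plain reflection. The following is a small sketch with a made-up generic class Box<T>; it compares the MethodInfo objects you get for the same method on two closed generic types, which is the same kind of object the aspect uses as its key.

```csharp
using System;
using System.Reflection;

// Hypothetical generic class, only here to have a method to look up.
public class Box<T>
{
    public static int Twice(int i) { return i * 2; }
}

public static class GenericMethodKeyDemo
{
    // True when the two closed generic types share one MethodInfo for Twice.
    public static bool SameMethodKey()
    {
        MethodInfo forInt = typeof(Box<int>).GetMethod("Twice");
        MethodInfo forString = typeof(Box<string>).GetMethod("Twice");
        return forInt.Equals(forString);
    }

    public static void Main()
    {
        Console.WriteLine(SameMethodKey()); // False
    }
}
```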

Well, this concludes the functionality we set out to build. The source code can be downloaded here
Please don't forget about the limitations mentioned in the second post that were not addressed in this one.
If you have any comments or notes, please feel free to add them.
Have a nice day!

Monday, February 18, 2013

Memorizing method results with PostSharp (part 2)

This post is the continuation of the previous post, Memorizing method results with PostSharp (part 1).
In the previous post we created an attribute that allowed caching method results. The caching was implemented with a static dictionary that stored a reference to the objects and to the method arguments as well. That solution is hardly a good one for a production environment, as it can lead to memory leaks: the dictionary just keeps growing, without allowing the GC to collect the objects that we don't use any more.

Reference to the object instance

Help from PostSharp

Now PostSharp has a solution for the first of our problems, holding a reference to the object itself.
In the PostSharp documentation you can read about the lifecycle and scope of PostSharp aspects. The instance-scoped aspects seem to be exactly what we need. PostSharp creates one instance of the attribute for each instance of the class it is used on. In addition, it automatically has an instance for static methods.
The documentation states that we only need to implement the IInstanceScopedAspect interface, and the magic will happen.
This interface has 2 methods:
object IInstanceScopedAspect.CreateInstance(AdviceArgs adviceArgs)
void IInstanceScopedAspect.RuntimeInitializeInstance()
The first one is called to create the aspect when a new instance of the class is created. The second one is called on the newly created aspect.
We will simply copy over all the data from our aspect, and that should do. Since we are creating an instance aspect, we should be able to remove the first 2 keys from the dictionary, and just use the arguments as keys and the return value as the value.

Unit tests

Let's start by extending our test class with a static method and a second instance method, and then write some tests for the static method and the 2 different instance methods.
        public class TestClass
        {
            public static int StaticNrOfExecutions = 0;
            [CacheResultAttribute.CacheResult(CacheLocations.Static)]
            public static int StaticDoubleTheNumber(int i)
            {
                StaticNrOfExecutions++;
                return i * 2;
            }

            public int CallStaticDoubleTheNumber(int i)
            {
                return StaticDoubleTheNumber(i);
            }
            
            public int NrOfExecutions = 0;
            public int NrOfExecutions2 = 0;
            [CacheResultAttribute.CacheResult(CacheLocations.Static)]
            public int DoubleTheNumber(int i)
            {
                NrOfExecutions++;
                return i * 2;
            }
            [CacheResultAttribute.CacheResult(CacheLocations.Static)]
            public int DoubleTheNumber2(int i)
            {
                NrOfExecutions2++;
                return i * 2;
            }
        }
So now having this class, let's write a test for static methods:
        [TestMethod]
        public void TestStaticResultCached()
        {
            var sa = new TestClass();
            var sb = new TestClass();
            Assert.AreEqual(0, TestClass.StaticNrOfExecutions);
            TestClass.StaticDoubleTheNumber(1);
            Assert.AreEqual(1, TestClass.StaticNrOfExecutions);
            TestClass.StaticDoubleTheNumber(2);
            Assert.AreEqual(2, TestClass.StaticNrOfExecutions);
            TestClass.StaticDoubleTheNumber(1);
            Assert.AreEqual(2, TestClass.StaticNrOfExecutions);
            sa.CallStaticDoubleTheNumber(4);
            Assert.AreEqual(3, TestClass.StaticNrOfExecutions);
            sa.CallStaticDoubleTheNumber(2);
            Assert.AreEqual(3, TestClass.StaticNrOfExecutions);
            sb.CallStaticDoubleTheNumber(4);
            Assert.AreEqual(3, TestClass.StaticNrOfExecutions);
        }
One for the 2 instance methods:
        [TestMethod]
        public void MethodsCachedSeparately()
        {
            var sa = new TestClass();
            Assert.AreEqual(0, sa.NrOfExecutions);
            Assert.AreEqual(0, sa.NrOfExecutions2);
            sa.DoubleTheNumber(1);
            Assert.AreEqual(1, sa.NrOfExecutions);
            Assert.AreEqual(0, sa.NrOfExecutions2);
            sa.DoubleTheNumber(2);
            Assert.AreEqual(2, sa.NrOfExecutions);
            Assert.AreEqual(0, sa.NrOfExecutions2);
            sa.DoubleTheNumber(1);
            Assert.AreEqual(2, sa.NrOfExecutions);
            Assert.AreEqual(0, sa.NrOfExecutions2);
            sa.DoubleTheNumber2(1);
            Assert.AreEqual(2, sa.NrOfExecutions);
            Assert.AreEqual(1, sa.NrOfExecutions2);
            sa.DoubleTheNumber2(2);
            Assert.AreEqual(2, sa.NrOfExecutions);
            Assert.AreEqual(2, sa.NrOfExecutions2);
            sa.DoubleTheNumber2(1);
            Assert.AreEqual(2, sa.NrOfExecutions);
            Assert.AreEqual(2, sa.NrOfExecutions2);

        }
and then a test for the same method on 2 objects:
        [TestMethod]
        public void TestDifferentObjectsCached()
        {
            var sa = new TestClass();
            var sb = new TestClass();
            Assert.AreEqual(0, sa.NrOfExecutions);
            Assert.AreEqual(0, sb.NrOfExecutions);
            sa.DoubleTheNumber(1);
            Assert.AreEqual(1, sa.NrOfExecutions);
            Assert.AreEqual(0, sb.NrOfExecutions);
            sa.DoubleTheNumber(2);
            Assert.AreEqual(2, sa.NrOfExecutions);
            Assert.AreEqual(0, sb.NrOfExecutions);
            sa.DoubleTheNumber(1);
            Assert.AreEqual(2, sa.NrOfExecutions);
            Assert.AreEqual(0, sb.NrOfExecutions);
            sb.DoubleTheNumber(1);
            Assert.AreEqual(2, sa.NrOfExecutions);
            Assert.AreEqual(1, sb.NrOfExecutions);
        }
We trust PostSharp that the instance attribute will go away with the object itself, so we don't write a test for that.

The implementation

With the implementation we don't have a lot to change. We implement the required interface, and change our dictionary to only hold references to the arguments.
    public class CacheResultAttribute : PostSharp.Aspects.OnMethodBoundaryAspect, IInstanceScopedAspect
    {
        private IDictionary<object, object> methodCache = new Dictionary<object, object>();

        // previous code
        [...]

        #region IInstanceScopedAspect implementation
        object IInstanceScopedAspect.CreateInstance(AdviceArgs adviceArgs)
        {
            var result = this.MemberwiseClone() as CacheResultAttribute;
            result.methodCache = new Dictionary<object, object>();
            return result;
        }

        void IInstanceScopedAspect.RuntimeInitializeInstance()
        {
        }
        #endregion IInstanceScopedAspect implementation
    }
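One detail worth spelling out: MemberwiseClone only makes a shallow copy, so without the extra assignment in CreateInstance every instance would share one dictionary. Here is a small sketch of the difference, using a made-up CacheHolder class as a stand-in for the aspect (no PostSharp involved).

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for the aspect: one field holding the cache.
public class CacheHolder
{
    public Dictionary<object, object> Cache = new Dictionary<object, object>();

    // Shallow copy only: the clone points at the same dictionary.
    public CacheHolder ShallowCopy()
    {
        return (CacheHolder)MemberwiseClone();
    }

    // Shallow copy plus a fresh dictionary, the same idea as in CreateInstance.
    public CacheHolder CopyWithOwnCache()
    {
        var copy = (CacheHolder)MemberwiseClone();
        copy.Cache = new Dictionary<object, object>();
        return copy;
    }
}

public static class CloneDemo
{
    public static void Main()
    {
        var prototype = new CacheHolder();
        Console.WriteLine(ReferenceEquals(prototype.Cache, prototype.ShallowCopy().Cache));      // True
        Console.WriteLine(ReferenceEquals(prototype.Cache, prototype.CopyWithOwnCache().Cache)); // False
    }
}
```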
As you can see, we completely removed the need to store anything statically. Since the aspect lives as long as the object, and for static methods statically, no more implementation is required.
With this our methods get shorter as well, since we need fewer dictionary key lookups.
        public override void OnEntry(PostSharp.Aspects.MethodExecutionArgs args)
        {
            if (this.cacheLocation != CacheLocations.Static)
                throw new NotImplementedException("Only static cache location is implemented for method return cache");
            var item = args.Instance ?? args.Method.ReflectedType;
            object argsKey = args.Arguments.Count == 0 ? string.Empty : tupleCreator(args.Method.GetParameters().Select(x => x.ParameterType).ToArray(), args.Arguments.ToArray());
            if (methodCache.ContainsKey(argsKey))
            {
                args.ReturnValue = methodCache[argsKey];
                args.FlowBehavior = PostSharp.Aspects.FlowBehavior.Return;
            }

            base.OnEntry(args);
        }

        public override void OnExit(PostSharp.Aspects.MethodExecutionArgs args)
        {
            if (this.cacheLocation != CacheLocations.Static)
                throw new NotImplementedException("Only static cache location is implemented for method return cache");
            var item = args.Instance ?? args.Method.ReflectedType;
            object argsKey = args.Arguments.Count == 0 ? string.Empty : tupleCreator(args.Method.GetParameters().Select(x => x.ParameterType).ToArray(), args.Arguments.ToArray());
            methodCache.Add(argsKey, args.ReturnValue);

            base.OnExit(args);
        }
Great, now running our unit tests, they all pass. Amazing! First problem resolved, thank you PostSharp.

Reference to arguments

Now the other problem we have is the hard references to arguments. Let's create a new TestClass2, with a method that takes a TestClass as an argument.
        public class TestClass2
        {
            public int NrOfExecutions = 0;
            [CacheResultAttribute.CacheResult(CacheLocations.Static)]
            public int DoSomething(TestClass tc)
            {
                NrOfExecutions++;
                return 4;
            }
        }
Now we want to make sure that the caching does not hold a reference to our argument object. Let's write a test:
        [TestMethod]
        public void TestReferenceToArgumentsNotRemembered()
        {
            var ta = new TestClass();
            var t2 = new TestClass2();
            Assert.AreEqual(0, t2.NrOfExecutions);
            t2.DoSomething(ta);
            Assert.AreEqual(1, t2.NrOfExecutions);
            t2.DoSomething(ta);
            Assert.AreEqual(1, t2.NrOfExecutions);
            var wr = new WeakReference(ta);
            ta = null;
            GC.Collect();
            Assert.AreEqual(null, wr.Target);
        }
Here we simply create the instance ta, then use it to call the method. Then we create a WeakReference to it. WeakReference is a built-in framework class that refers to an object without keeping it alive: its Target returns the object as long as the object is still referenced somewhere else, and null once it has been collected. So the WeakReference itself does not block garbage collection of that object.
We remove our reference to ta, and we force a garbage collection, which should normally result in ta being garbage collected. If you run this test, you will see it fail, as the Dictionary inside our attribute has a strong reference to ta, and prevents the GC from collecting it.
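The same effect can be seen without any aspect at all. The sketch below (all names in it are made up for the illustration) puts a key into a plain Dictionary and into a ConditionalWeakTable and checks after a forced collection whether the key is still alive. The keys are created in separate non-inlined methods so that no local variable keeps them alive by accident. On a typical release build I would expect the first line to print True and the second False, but the garbage collector gives no hard guarantee for the second one, so take it as an illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;

public static class KeyLifetimeDemo
{
    [MethodImpl(MethodImplOptions.NoInlining)]
    private static WeakReference AddToDictionary(Dictionary<object, object> cache)
    {
        var key = new object();
        cache.Add(key, 4);
        return new WeakReference(key);
    }

    [MethodImpl(MethodImplOptions.NoInlining)]
    private static WeakReference AddToWeakTable(ConditionalWeakTable<object, object> cache)
    {
        var key = new object();
        cache.Add(key, 4);
        return new WeakReference(key);
    }

    // The dictionary holds a strong reference, so the key survives the collection.
    public static bool KeyAliveInDictionary()
    {
        var cache = new Dictionary<object, object>();
        WeakReference wr = AddToDictionary(cache);
        GC.Collect();
        GC.WaitForPendingFinalizers();
        bool alive = wr.IsAlive;
        GC.KeepAlive(cache); // the dictionary itself must stay reachable until here
        return alive;
    }

    // The weak table does not keep its keys alive, so the key can be collected.
    public static bool KeyAliveInWeakTable()
    {
        var cache = new ConditionalWeakTable<object, object>();
        WeakReference wr = AddToWeakTable(cache);
        GC.Collect();
        GC.WaitForPendingFinalizers();
        bool alive = wr.IsAlive;
        GC.KeepAlive(cache);
        return alive;
    }

    public static void Main()
    {
        Console.WriteLine(KeyAliveInDictionary());
        Console.WriteLine(KeyAliveInWeakTable());
    }
}
```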

Help again comes from outside. The framework provides a class with functionality similar to WeakReference, just in a table-like form. It is called ConditionalWeakTable.
Let's try to use this class instead of the dictionary we had.
In reality, we were using Tuple<> to wrap the arguments, so it is the tuple that holds the reference to the object, and the dictionary holds a reference to the tuple. We will rewrite our cache to use the arguments as keys directly, with the ConditionalWeakTable class. We need to handle the argumentless case separately, but otherwise we store a ConditionalWeakTable for each argument, so an argument list of (string, int, Type) becomes ConditionalWeakTable<object, ConditionalWeakTable<object, ConditionalWeakTable<object, object>>>. As you can see, we store everything as objects. This might cause a problem, but let's see.
        /// <summary>
        /// the result in case there are no arguments
        /// </summary>
        private object noArgResult;
        /// <summary>
        /// whether the no argument result has been calculated
        /// this is required, as the result can be null
        /// </summary>
        private bool isNoArgResultInitialized = false;

        [NonSerialized]
        private ConditionalWeakTable<object, object> methodCache = new ConditionalWeakTable<object, object>();

        //previous code
        [...]

        public override void OnEntry(PostSharp.Aspects.MethodExecutionArgs args)
        {
            if (this.cacheLocation != CacheLocations.Static)
                throw new NotImplementedException("Only static cache location is implemented for method return cache");

            if (args.Arguments.Count == 0)
            {
                if (isNoArgResultInitialized)
                {
                    args.ReturnValue = noArgResult;
                    args.FlowBehavior = FlowBehavior.Return;
                }
            }
            else
            {
                var types = args.Method.GetParameters().Select(x => x.ParameterType).ToList();
                bool isMissingKey = false;
                var resultDictionary = this.methodCache;
                for (int i = 0; i < args.Arguments.Count - 1; i++)
                {
                    var arg = args.Arguments[i];
                    object result;
                    if (resultDictionary.TryGetValue(arg, out result))
                    {
                        resultDictionary = result as ConditionalWeakTable<object, object>;
                    }
                    else
                    {
                        isMissingKey = true;
                        break;
                    }
                }
                object finalresult;
                if (!isMissingKey && resultDictionary.TryGetValue(args.Arguments.Last(), out finalresult))
                {
                    args.ReturnValue = finalresult;
                    args.FlowBehavior = FlowBehavior.Return;
                }

            }

            base.OnEntry(args);
        }

        public override void OnExit(PostSharp.Aspects.MethodExecutionArgs args)
        {
            if (this.cacheLocation != CacheLocations.Static)
                throw new NotImplementedException("Only static cache location is implemented for method return cache");

            if (args.Arguments.Count == 0)
            {
                isNoArgResultInitialized = true;
                noArgResult = args.ReturnValue;
            }
            else
            {
                var table = methodCache;
                for (int i = 0; i < args.Arguments.Count - 1; i++)
                {
                    var arg = args.Arguments[i];
                    object nextTable;
                    if (table.TryGetValue(arg, out nextTable))
                    {
                        table = nextTable as ConditionalWeakTable<object, object>;
                    }
                    else
                    {
                        var nextTable2 = new ConditionalWeakTable<object, object>();
                        table.Add(arg, nextTable2);
                        table = nextTable2;
                    }
                }
                table.Add(args.Arguments.Last(), args.ReturnValue);
            }

            base.OnExit(args);
        }
Well, it looks a bit more complicated, but not that much. We simply go through the arguments, and whenever there is another argument still to come, we add a new ConditionalWeakTable.

Cool, let's run our unit tests, and .... BANG. The last one passes, but the previous ones all fail. Whaaaaat? What just happened?
To understand what happened, we need to have a look at the ConditionalWeakTable documentation, and see the note in the middle of the page: "ConditionalWeakTable class supports one attached value per managed object. Two keys are equal if passing them to the Object.ReferenceEquals method returns true". Hmm, OK, but we didn't really pass in objects in our failed tests. We passed in plain integers, so why did it fail?
Well, the answer is boxing. C# allows boxing of any value into an object. This allows you to add an integer to an object list, like
var list = new List<object>();
int i = 8;
list.Add(i);
Here the compiler actually boxes the value 8 into an object, and adds that object to the list.
Combining this with the fact that the ConditionalWeakTable compares keys with object.ReferenceEquals, we now understand what happened. We have 2 different objects, which both hold the same integer, but object.ReferenceEquals returns false for them.
A quick test confirms it.

        [TestMethod]
        public void Test()
        {
            int i = 6;
            object a = i;
            object b = i;
            Assert.IsFalse(Object.ReferenceEquals(a, b));
        }
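To see what this means for lookups, here is a small sketch (class and method names are made up) that stores a value under one boxed integer and then looks it up with a second box of the same integer. The ConditionalWeakTable misses, because it compares references, while a Hashtable finds the entry, because it compares with Equals.

```csharp
using System;
using System.Collections;
using System.Runtime.CompilerServices;

public static class BoxedKeyDemo
{
    // Looks up a second box of the same int in a ConditionalWeakTable.
    public static bool FoundInWeakTable()
    {
        int i = 6;
        object a = i;
        object b = i;
        var table = new ConditionalWeakTable<object, object>();
        table.Add(a, "cached");
        object value;
        bool found = table.TryGetValue(b, out value);
        GC.KeepAlive(a); // keep the first box alive during the lookup
        return found;
    }

    // The same lookup against a Hashtable, which uses Equals/GetHashCode.
    public static bool FoundInHashtable()
    {
        int i = 6;
        object a = i;
        object b = i;
        var table = new Hashtable();
        table.Add(a, "cached");
        return table.ContainsKey(b);
    }

    public static void Main()
    {
        Console.WriteLine(FoundInWeakTable());  // False
        Console.WriteLine(FoundInHashtable());  // True
    }
}
```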
So now what? What we really want is not to have memory used up once it is no longer required. How do we know it is not required anymore? For an object it is easy: if no references to it exist, we don't need it. But for a number? Or a string? Is it even our responsibility to care about this?
Well, this is debatable. The functionality we are trying to achieve is caching method results. Now if you cache a result for the number 3, do you want to hold on to that result forever? With our intention of having a static cache, that is probably what you want. So for simple types, it is probably not an issue to store them in a dictionary directly. A problem might arise when you start to use quite complex structs as arguments, as those will never be garbage collected. This is a limitation that we will just accept; for the time being, it is good enough.
So for value types we will just use a dictionary, and we will use the ConditionalWeakTable for the rest.
Implementing this gives us the code:

        public override void OnEntry(PostSharp.Aspects.MethodExecutionArgs args)
        {
            if (this.cacheLocation != CacheLocations.Static)
                throw new NotImplementedException("Only static cache location is implemented for method return cache");

            if (!this.methodCache.ContainsKey(args.Method))
            {
                this.methodCache.Add(args.Method, null);
            }

            if (args.Arguments.Count == 0)
            {
                if (isNoArgResultInitialized)
                {
                    args.ReturnValue = methodCache[args.Method];
                    args.FlowBehavior = FlowBehavior.Return;
                }
            }
            else
            {
                var types = args.Method.GetParameters().Select(x => x.ParameterType).ToList();
                bool isMissingKey = false;
                object resultDictionary = this.methodCache;
                if (resultDictionary != null)
                {
                    for (int i = 0; i < args.Arguments.Count; i++)
                    {
                        object result;
                        var arg = args.Arguments[i];
                        if (!types[i].IsValueType)
                        {
                            if ((resultDictionary as ConditionalWeakTable<object, object>).TryGetValue(arg, out result))
                            {
                                resultDictionary = result;
                            }
                            else
                            {
                                isMissingKey = true;
                                break;
                            }
                        }
                        else
                        {
                            if ((resultDictionary as IDictionary).Contains(arg))
                            {
                                resultDictionary = (resultDictionary as IDictionary)[arg];
                            }
                            else
                            {
                                isMissingKey = true;
                                break;
                            }
                        }

                    }
                    if (!isMissingKey)
                    {
                        args.ReturnValue = resultDictionary;
                        args.FlowBehavior = FlowBehavior.Return;
                    }
                }
            }

            base.OnEntry(args);
        }

        public override void OnExit(PostSharp.Aspects.MethodExecutionArgs args)
        {
            if (this.cacheLocation != CacheLocations.Static)
                throw new NotImplementedException("Only static cache location is implemented for method return cache");

            if (args.Arguments.Count == 0)
            {
                isNoArgResultInitialized = true;
                methodCache[args.Method] = args.ReturnValue;
            }
            else
            {
                var types = args.Method.GetParameters().Select(x => x.ParameterType).ToList();
                Func<object> getter = () => methodCache;
                Action<object> setter = o => methodCache = o;
                object table = getter();
                for (int i = 0; i < args.Arguments.Count; i++)
                {
                    var arg = args.Arguments[i];
                    if (!types[i].IsValueType)
                    {
                        if (getter() == null)
                        {
                            setter(new ConditionalWeakTable<object, object>());
                        }
                        var currenttable = getter();

                        getter = () =>
                        {
                            object temp;
                            (currenttable as ConditionalWeakTable<object, object>).TryGetValue(arg, out temp);
                            return temp;
                        };
                        setter = o =>
                        {
                            (currenttable as ConditionalWeakTable<object, object>).Add(arg, o);
                        };
                    }
                    else
                    {
                        if (getter() == null)
                        {
                            setter(new Hashtable());
                        }
                        var currenttable = getter();
                        getter = () => (currenttable as Hashtable)[arg];
                        setter = o =>
                        {
                            (currenttable as Hashtable).Add(arg, o);
                        };
                    }

                }
                setter(args.ReturnValue);
            }

            base.OnExit(args);
        }


Now when running our unit tests, all of them pass. So we managed what we set out to do.

Limitations

Well, as we saw above, we might leak some memory when using large value type objects, like complex structs, as parameters. This is a VERY IMPORTANT limitation, and you do have to be aware of it when you use this attribute extensively.

Generics

Well, what about generic methods and generic types? The above implementation does not really support them, so you have to keep that in mind. But the implementation can be extended to support those concepts too.

Null values

I'm not sure, but I think ConditionalWeakTable does not support null keys, so an extra implementation is required to handle null arguments too.

Another limitation I can think of comes up when you have a reference type parameter after a value type parameter in the argument list. For example, take [CacheResult] int Method(int number, TestClass testClass). When we first call this with 0 and ta (where ta is a TestClass instance), the result is cached in a Hashtable: for the key 0 it holds a ConditionalWeakTable<object, object>, which has a weak reference to ta.
Now when ta goes out of scope, the ConditionalWeakTable will lose the reference to it, but the Hashtable will still contain an empty ConditionalWeakTable object. This is obviously some wasted memory. Even though it is less than holding a reference to the class itself, if you use this attribute extensively, you should consider addressing that issue as well.

If you are curious how to remove some of these limitations, please follow on to the next post.

Sunday, February 17, 2013

Memorizing method results with PostSharp (part 1)

Regularly, when I need a new functionality of some sort and try to think of the problem in general, outside of the current problem context, I run into the need to cache the result of a specific method, in order to speed up the code if someone (even me) decides to use it extensively.

The problem

I mean, how many times have you seen code like this in your life:
    public class SomeClass
    {
        private static Dictionary<Tuple<Argument1Type, Argument2Type>, MyResultType> someExpensiveMethodsomeMethodResultCache = new Dictionary<Tuple<Argument1Type, Argument2Type>, MyResultType>();

        public MyResultType SomeExpensiveMethod(Argument1Type arg1, Argument2Type arg2)
        {
            var key = Tuple.Create(arg1, arg2);
            if (!someExpensiveMethodsomeMethodResultCache.ContainsKey(key))
            {
                MyResultType result = null;
                // fill result;
                [...]
                someExpensiveMethodsomeMethodResultCache.Add(key, result);
            }
            return someExpensiveMethodsomeMethodResultCache[key];
        }

    }

Now, obviously you could cache the result non-statically, in an object cache, or any other way, but the bottom line is that it is quite a lot of code and responsibility for a functionality that I would much prefer just to 'attach' to the method, right?
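For comparison, without any aspects you could at least pull that boilerplate into a reusable helper that wraps a delegate. To be clear, this is a different technique from the attribute this post is about, and the Memoizer class below is just a sketch of mine: it is not thread-safe and it holds strong references to its arguments, like the code above.

```csharp
using System;
using System.Collections.Generic;

public static class Memoizer
{
    // Wraps a two-argument function so each distinct argument pair is computed once.
    public static Func<T1, T2, TResult> Memoize<T1, T2, TResult>(Func<T1, T2, TResult> inner)
    {
        var cache = new Dictionary<Tuple<T1, T2>, TResult>();
        return (arg1, arg2) =>
        {
            var key = Tuple.Create(arg1, arg2);
            TResult result;
            if (!cache.TryGetValue(key, out result))
            {
                result = inner(arg1, arg2);
                cache.Add(key, result);
            }
            return result;
        };
    }
}

public static class MemoizerDemo
{
    public static int Executions = 0;

    public static void Main()
    {
        Func<int, int, int> add = Memoizer.Memoize<int, int, int>((a, b) =>
        {
            Executions++;
            return a + b;
        });
        Console.WriteLine(add(1, 2));   // 3
        Console.WriteLine(add(1, 2));   // 3, served from the cache
        Console.WriteLine(Executions);  // 1
    }
}
```

It works, but the caller still has to wrap every method by hand, which is exactly the kind of ceremony I would like to get rid of.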

The goal

As we all know, there is a mechanism in the .NET Framework for 'attaching' functionality to code elements: attributes.
Now what I'd really like to do is simply say:
    public class SomeClass
    {
        [CacheResult(CacheLocations.Static)]
        public MyResultType SomeExpensiveMethod(Argument1Type arg1, Argument2Type arg2)
        {
            MyResultType result = null;
            // fill result;
            [...]
            return result;
        }
    }
Well, I do think so, so let's try to create an attribute like that.

The solution

Now one major thing to understand here is that we want to interfere with the normal execution of the program and hijack it. Basically, if we already have the cached version, we just return that and do not execute the code again.
Here is where a new concept comes in, called aspects.
If you are not familiar with the term aspects, you should read more about it on Wikipedia here, or here, or just google for it!! It is a truly exceptional field, and can do really amazing things, for example what we are doing here :)
Now a very nice tool for such things is PostSharp, which has a free license for 1 developer. It is a limited edition, but it gets you enough to accomplish what we want here.

A few words on PostSharp

Well, it is useful to understand how PostSharp works and what it can do for you. PostSharp is basically a library that plugs into the build pipeline. Once your code is compiled to IL, PostSharp goes through the compiled code and looks for attributes that derive from PostSharp's own aspect attributes. It then changes the method implementations according to what those attributes are supposed to do.
It can intercept method calls and property calls, and a lot more. You can read about the capabilities on the SharpCrafters website.
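To make the idea of "changing the method implementation" a bit more concrete, here is a hand-written sketch of what a woven method conceptually does: call a hook on entry, optionally skip the original body, and call a hook on exit. This is only an illustration and not PostSharp's actual generated code; SketchArgs, SketchAspect, WeavingSketch and the two example aspects are hypothetical names made up for this sketch.

```csharp
using System;

// Hypothetical stand-in for the argument object passed to the hooks (NOT a PostSharp type).
public class SketchArgs
{
    public object[] Arguments;
    public object ReturnValue;
    public bool SkipMethod;   // set by OnEntry to return early without running the body
}

// Hypothetical stand-in for an aspect with entry and exit hooks.
public abstract class SketchAspect
{
    public virtual void OnEntry(SketchArgs args) { }
    public virtual void OnExit(SketchArgs args) { }
}

// Example aspect: counts how often each hook runs.
public class CountingAspect : SketchAspect
{
    public int Entries;
    public int Exits;
    public override void OnEntry(SketchArgs args) { Entries++; }
    public override void OnExit(SketchArgs args) { Exits++; }
}

// Example aspect: supplies a result itself, so the original body is never executed.
public class AlwaysReturn42 : SketchAspect
{
    public override void OnEntry(SketchArgs args)
    {
        args.ReturnValue = 42;
        args.SkipMethod = true;
    }
}

public static class WeavingSketch
{
    // Conceptual shape of a woven int -> int method.
    public static int Invoke(SketchAspect aspect, Func<int, int> originalBody, int input)
    {
        var args = new SketchArgs { Arguments = new object[] { input } };
        aspect.OnEntry(args);
        if (args.SkipMethod)
            return (int)args.ReturnValue;   // hijacked: the original body does not run

        try
        {
            args.ReturnValue = originalBody(input);
            return (int)args.ReturnValue;
        }
        finally
        {
            aspect.OnExit(args);            // runs after the body, even if it throws
        }
    }
}
```

Calling WeavingSketch.Invoke(new AlwaysReturn42(), x => x * 2, 3) returns 42 without ever invoking the lambda, which is exactly the kind of short circuit a result cache needs.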

Ok, so now that we have a general idea of the tools we can use, let's jump into the unit tests.

Unit Tests

Generally when creating a unit test for a functionality, you also need to come up with the interfaces and classes that you will use to interact with the functionality.
So first we will need an attribute, let's say CacheResultAttribute.
    [AttributeUsage(AttributeTargets.Method, AllowMultiple=false)]
    public class CacheResultAttribute : Attribute
    {
    }
Now for the sake of future requirements, when you will not want to cache the results statically but in another way, let's allow the cache location to be specified with an enum called CacheLocations.
    public enum CacheLocations
    {
        Static,
    }
and also add this as a constructor parameter for the attribute:
        private readonly CacheLocations cacheLocation;
        public CacheResultAttribute(CacheLocations location)
        {
            cacheLocation = location;
        }
Now we can create a test class that counts how many times the method has been executed.
    [TestClass]
    public class CacheResultAttributeTest
    {
        public class TestClass
        {
            public int NrOfExecutions = 0;
            [CacheResultAttribute.CacheResult(CacheLocations.Static)]
            public int DoubleTheNumber(int i)
            {
                NrOfExecutions++;
                return i * 2;
            }
        }
    }
As you can see, I already decorated the DoubleTheNumber method with the attribute I just created above. I know public fields are not a good idea, but for the sake of keeping the test code short, it's not a big deal.
Now in our tests we simply want to see that the DoubleTheNumber method is not executed again for parameters that were already passed in. The test method is fairly simple,
        [TestMethod]
        public void TestResultCached()
        {
            var subject = new TestClass();
            Assert.AreEqual(0, subject.NrOfExecutions);
            var result = subject.DoubleTheNumber(3);
            // we dont care if the method is actually ok, so this is not needed
            //Assert.AreEqual(6, result);

            //but we want to check if the execution count went up
            Assert.AreEqual(1, subject.NrOfExecutions);

            // now lets see if it is called with other parameter
            result = subject.DoubleTheNumber(2);
            //again check that execution nr went up
            Assert.AreEqual(2, subject.NrOfExecutions);

            //now get the result for 3 again
            result = subject.DoubleTheNumber(3);
            // and make sure we didnt execute the method
            Assert.AreEqual(2, subject.NrOfExecutions);
        }
it doesn't even need an explanation.

Great, now we have a test, so let's run it, and we get the first test result: "Result Message: Assert.AreEqual failed. Expected:<2>. Actual:<3>."
Of course, since we didn't implement anything yet, the 3rd call just calls the function, and it is executed a 3rd time as well. So let's have a look at how this will be done.

The implementation

The first thing you have to do is get and install PostSharp. Please follow the steps on the website if you don't have it yet.
Once you have PostSharp and a reference to its dlls, we can start extending an attribute called OnMethodBoundaryAspect, which lives in the PostSharp dlls. This aspect allows you to intercept method calls. You have 3 methods you can override:
bool CompileTimeValidate(MethodBase method)
void OnEntry(MethodExecutionArgs args)
void OnExit(MethodExecutionArgs args)
The first one is executed, when PostSharp recompiles the code. This means that the code here is only executed at build time, it does not have an influence at runtime, so this can be as heavy as you want. Obviously you will have a longer compile time, but it will not be visible to the code users.
The second and third methods are called when entering and exiting a method.
The MethodExecutionArgs class contains all necessary data, that you will require at runtime to decide what to do. So lets jump in.
    [Serializable]  // required by PostSharp
    [AttributeUsage(AttributeTargets.Method, AllowMultiple=false)]
    public class CacheResultAttribute : PostSharp.Aspects.OnMethodBoundaryAspect
    {
        private readonly CacheLocations cacheLocation;
        public CacheResultAttribute(CacheLocations location)
        {
            cacheLocation = location;
        }

        public override bool CompileTimeValidate(System.Reflection.MethodBase method)
        {
            return base.CompileTimeValidate(method);
        }

        public override void OnEntry(PostSharp.Aspects.MethodExecutionArgs args)
        {
            base.OnEntry(args);
        }

        public override void OnExit(PostSharp.Aspects.MethodExecutionArgs args)
        {
            base.OnExit(args);
        }

    }
Now let's start with the first. We really don't need to do anything here, but just for the sake of understanding, let's not allow applying this attribute to void methods and constructors. That would be misleading anyway. So let's check whether we are looking at a constructor or a void method, and if so, report an error.
        public override bool CompileTimeValidate(System.Reflection.MethodBase method)
        {
            if (method == null)
            {
                PostSharp.Extensibility.Message error = new PostSharp.Extensibility.Message(MessageLocation.Explicit("CacheResultAttribute", 25, 0), SeverityType.Error, "AOP0001", "Method is null", "#", "CacheResultAttribute.cs", null);
                MessageSource.MessageSink.Write(error);
                return false;
            }
            if(method.IsConstructor)
            {
                PostSharp.Extensibility.Message error = new PostSharp.Extensibility.Message(MessageLocation.Explicit("CacheResultAttribute", 25, 0), SeverityType.Error, "AOP0001", "Attribute cannot be applied to constructors", "#", "CacheResultAttribute.cs", null);
                MessageSource.MessageSink.Write(error);
                return false;
            }
            if (!(method is MethodInfo))
            {
                PostSharp.Extensibility.Message error = new PostSharp.Extensibility.Message(MessageLocation.Explicit("CacheResultAttribute", 25, 0), SeverityType.Error, "AOP0001", string.Format("Attribute cannot be applied to method {0} because it cannot be cast to a MethodInfo", method.Name), "#", "CacheResultAttribute.cs", null);
                MessageSource.MessageSink.Write(error);
                return false;
            }
            if((method as MethodInfo).ReturnType == typeof(void))
            {
                PostSharp.Extensibility.Message error = new PostSharp.Extensibility.Message(MessageLocation.Explicit("CacheResultAttribute", 25, 0), SeverityType.Error, "AOP0001", string.Format("Attribute cannot be applied to method {0} because it has a void return type", method.Name), "#", "CacheResultAttribute.cs", null);
                MessageSource.MessageSink.Write(error);
                return false;
            }
            return base.CompileTimeValidate(method);
        }

Really not much of interest here. So let's get to the next one: what do we want to do in the OnEntry and OnExit methods?
Well, PostSharp allows you to hijack method execution and skip running the method altogether by setting args.FlowBehavior = FlowBehavior.Return. This will be very handy, since we will do just that when we find a cached value.
So we have all the arguments in the args.Arguments property, we have the method info in the args.Method property, and the instance (if it is not a static method) in the args.Instance property.
Now one convention we will use is that if we encounter a static method, we store its results under the type on which it is defined. Otherwise we store the results under the object instance the method was called on. Since System.Type does not use our attribute (we are just creating it), we should have no problems with duplicate dictionary keys. So we will index our static cache by the following keys:
1) Object instance or type
2) MethodBase
3) arguments
With these assumptions we have the following code:
        private static readonly IDictionary<object, IDictionary<MethodBase, IDictionary<object, object>>> staticMethodCache = new Dictionary<object, IDictionary<MethodBase, IDictionary<object, object>>>();

        public override void OnEntry(PostSharp.Aspects.MethodExecutionArgs args)
        {
            if (this.cacheLocation != CacheLocations.Static)
                throw new NotImplementedException("Only static cache location is implemented for method return cache");
            var item = args.Instance ?? args.Method.ReflectedType;
            if (staticMethodCache.ContainsKey(item)
                && staticMethodCache[item] != null
                && staticMethodCache[item].ContainsKey(args.Method)
                && staticMethodCache[item][args.Method] != null)
            {
                object argsKey = args.Arguments.Count == 0 ? string.Empty : tupleCreator(args.Method.GetParameters().Select(x => x.ParameterType).ToArray(), args.Arguments.ToArray());
                if (staticMethodCache[item][args.Method].ContainsKey(argsKey))
                {
                    args.ReturnValue = staticMethodCache[item][args.Method][argsKey];
                    args.FlowBehavior = PostSharp.Aspects.FlowBehavior.Return;
                }
            }
            base.OnEntry(args);
        }
We simply checked if we have the value in the dictionary, and if yes, then we set the return value, and break the flow.
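The chain of ContainsKey and indexer calls above looks up each level more than once. As a hedged alternative, the same three-level lookup can be written with TryGetValue, which finds each level in a single call. NestedLookupSketch and TryGetCached are hypothetical names introduced for this sketch; the dictionary shape is the one used by staticMethodCache.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

public static class NestedLookupSketch
{
    // Walks owner -> method -> arguments key and reports whether a cached value exists.
    public static bool TryGetCached(
        IDictionary<object, IDictionary<MethodBase, IDictionary<object, object>>> cache,
        object owner,
        MethodBase method,
        object argsKey,
        out object value)
    {
        value = null;
        IDictionary<MethodBase, IDictionary<object, object>> byMethod;
        IDictionary<object, object> byArgs;

        // Each && only continues when the previous level was found and is not null.
        return cache.TryGetValue(owner, out byMethod) && byMethod != null
            && byMethod.TryGetValue(method, out byArgs) && byArgs != null
            && byArgs.TryGetValue(argsKey, out value);
    }
}
```

With such a helper, OnEntry would only need to set args.ReturnValue and args.FlowBehavior when TryGetCached returns true.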
We used a helper method here called tupleCreator. It essentially creates a nested tuple from the arguments, built from the last argument backwards. For example, for arguments of types (int, string, Type) the result is of type Tuple<int, Tuple<string, Tuple<Type>>>.
        /// <summary>
        /// creates a tuple of the arguments
        /// </summary>
        /// <param name="types">the parameter types, in declaration order</param>
        /// <param name="arguments">the argument values, in the same order as the types</param>
        /// <returns>a nested tuple holding all the arguments</returns>
        private object tupleCreator(Type[] types, object[] arguments)
        {
            if (types == null)
                throw new ArgumentNullException("types");
            if (arguments == null)
                throw new ArgumentNullException("arguments");
            if (types.Length == 0)
                throw new ArgumentOutOfRangeException("types", "The specified type list needs at least 1 type");
            if (types.Length != arguments.Length)
                throw new ArgumentException("The specified argument count does not equal the type count", "arguments");
            var tupleCreator1 = typeof(Tuple).GetMethods(BindingFlags.Static | BindingFlags.Public).Where(x => x.Name == "Create" && x.GetGenericArguments().Length == 1).Single();
            var tupleCreator2 = typeof(Tuple).GetMethods(BindingFlags.Static | BindingFlags.Public).Where(x => x.Name == "Create" && x.GetGenericArguments().Length == 2).Single();

            // start with the innermost tuple, which wraps the last argument
            var result = tupleCreator1.MakeGenericMethod(types[types.Length -1]).Invoke(null, new object[]{arguments[types.Length -1]});
            // walk backwards through the remaining arguments (note the decrement), wrapping the result each time
            for (int i = types.Length -2; i > -1; i--)
            {
                result = tupleCreator2.MakeGenericMethod(types[i], result.GetType()).Invoke(null, new object[] { arguments[i], result });
            }
            return result;
        }

The method is quite straightforward. It can be improved by statically caching the two method infos of the Tuple factory methods. The number of calls to Invoke could also be reduced, since there are Tuple.Create overloads that take more arguments, but for now this will do.
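The reason a tuple works as a dictionary key is that System.Tuple compares its items in Equals and GetHashCode, so two tuples built from equal arguments are treated as the same key. The sketch below shows this together with the first improvement mentioned above (looking up the two Tuple.Create method infos only once). TupleKeySketch and BuildKey are hypothetical names for this illustration, and argument validation is left out to keep it short.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

public static class TupleKeySketch
{
    // The generic Tuple.Create<T1> and Tuple.Create<T1, T2> definitions, resolved once.
    private static readonly MethodInfo Create1 = typeof(Tuple)
        .GetMethods(BindingFlags.Static | BindingFlags.Public)
        .Single(m => m.Name == "Create" && m.GetGenericArguments().Length == 1);

    private static readonly MethodInfo Create2 = typeof(Tuple)
        .GetMethods(BindingFlags.Static | BindingFlags.Public)
        .Single(m => m.Name == "Create" && m.GetGenericArguments().Length == 2);

    // Builds Tuple<T0, Tuple<T1, ... Tuple<Tn>>> from the last argument backwards.
    // Assumes types and arguments are non-empty and have the same length.
    public static object BuildKey(Type[] types, object[] arguments)
    {
        int last = types.Length - 1;
        object result = Create1.MakeGenericMethod(types[last])
            .Invoke(null, new object[] { arguments[last] });

        for (int i = last - 1; i >= 0; i--)
        {
            result = Create2.MakeGenericMethod(types[i], result.GetType())
                .Invoke(null, new object[] { arguments[i], result });
        }
        return result;
    }
}
```

Two keys built from the arguments (1, "x", typeof(int)) are different objects, but they are equal and produce the same hash code, so a Dictionary<object, object> finds the entry stored under the first key when asked for the second.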
Now for the exit, we basically need to store the result in the dictionary, and we are done.
        public override void OnExit(PostSharp.Aspects.MethodExecutionArgs args)
        {
            if (this.cacheLocation != CacheLocations.Static)
                throw new NotImplementedException("Only static cache location is implemented for method return cache");
            var item = args.Instance ?? args.Method.ReflectedType;
            if (!staticMethodCache.ContainsKey(item))
            {
                staticMethodCache.Add(item, new Dictionary<MethodBase, IDictionary<object, object>>());
            }
            if (!staticMethodCache[item].ContainsKey(args.Method))
            {
                staticMethodCache[item].Add(args.Method, new Dictionary<object, object>());
            }
            object argsKey = args.Arguments.Count == 0 ? string.Empty : tupleCreator(args.Method.GetParameters().Select(x => x.ParameterType).ToArray(), args.Arguments.ToArray());
            staticMethodCache[item][args.Method].Add(argsKey, args.ReturnValue);
            base.OnExit(args);
        }
We create the dictionaries and insert the value. By the way, the Add call would throw an exception if we somehow managed to get here again with the same parameters.
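The exception comes from Dictionary.Add, which rejects a key that is already present with an ArgumentException, while the indexer simply overwrites the existing entry. The small sketch below shows the difference; CacheStoreSketch and its method names are hypothetical and only there for illustration.

```csharp
using System;
using System.Collections.Generic;

public static class CacheStoreSketch
{
    // Returns true when the second Add for the same key is rejected.
    public static bool AddThrowsOnDuplicate()
    {
        var cache = new Dictionary<object, object>();
        cache.Add("key", 1);
        try
        {
            cache.Add("key", 2);   // same key again
            return false;
        }
        catch (ArgumentException)
        {
            return true;
        }
    }

    // The indexer replaces the stored value instead of throwing.
    public static object StoreWithIndexer()
    {
        var cache = new Dictionary<object, object>();
        cache["key"] = 1;
        cache["key"] = 2;
        return cache["key"];
    }
}
```

Whether a second store for the same key should fail loudly (Add) or silently win (indexer) is a design choice; for a result cache the indexer is the more forgiving option.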

Well, now let's run our unit tests, and woohoo, they pass. So with just a couple of lines of code, we managed to move the caching code out of the method itself and into an aspect that we can now apply anywhere.

Room for improvement

As you can see, there is only 1 caching implementation, the one for static caching. I use this type of caching a lot when I create type maps and other methods that operate mostly on types. You could implement other caching mechanisms and extend the above.

Problems

DO NOT USE THIS CODE IN PRODUCTION!!!
Why do I say that?

Well, let's take a look at our dictionaries. The first key is the object itself. That means that any time this aspect is hit, a reference to that object is stored in the dictionary, which prevents that object from being garbage collected. This is a massive issue if you want to use this attribute on instance methods. As long as you use it on static methods, there is no major memory loss, but with instance methods you will end up using more and more memory. We will discuss this issue in an upcoming post.
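One possible direction for the instance problem, sketched here under my own assumptions, is to key the per-instance results with a ConditionalWeakTable from System.Runtime.CompilerServices. It holds its keys weakly, so an entry does not keep the key object alive. WeakInstanceCacheSketch and GetOrCompute are hypothetical names, and the int-to-int signature is only chosen to keep the example short.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;

public static class WeakInstanceCacheSketch
{
    // Keys (the instances) are referenced weakly; values live as long as their key does.
    private static readonly ConditionalWeakTable<object, Dictionary<int, int>> perInstance =
        new ConditionalWeakTable<object, Dictionary<int, int>>();

    // Returns the cached result for (instance, argument), computing it on first use.
    public static int GetOrCompute(object instance, int argument, Func<int, int> compute)
    {
        // Creates an empty result dictionary for instances seen for the first time.
        Dictionary<int, int> results = perInstance.GetOrCreateValue(instance);

        int value;
        if (!results.TryGetValue(argument, out value))
        {
            value = compute(argument);
            results[argument] = value;
        }
        return value;
    }
}
```

This only covers the instance key. The argument values stored inside the inner dictionary are still referenced strongly, and the sketch is not thread safe.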

The same problem as with object instances comes up with the arguments. The dictionaries hold references to those arguments as well, which does the same harm as above. As above, this will be addressed in the next post.
Cheers for reading.

Continue reading on the second part of this post