Tuesday, April 26, 2011

Dynamically Invoke Generic Method In .NET


Today I cleaned up a lengthy method in our project, which I think is worthy of a post. Our project has a repository that caches all business objects (we call them entities), and the list of entity types is long. The cache is loaded/refreshed when the repository is first accessed or when some entities have changed in the back-end store. The method for loading entities looked something like this:
// Illustrative "before" version: one hand-written load/save pair per entity
// type. EntityType1..EntityTypeN, DataStore and SaveEntitiesToCache are
// placeholders for the real project types and helpers.
void LoadEntitiesCache()
{
    List<EntityType1> entity1List = DataStore.GetAllEntitiesByType<EntityType1>();
    SaveEntitiesToCache(entity1List);
    List<EntityType2> entity2List = DataStore.GetAllEntitiesByType<EntityType2>();
    SaveEntitiesToCache(entity2List);
    List<EntityType3> entity3List = DataStore.GetAllEntitiesByType<EntityType3>();
    SaveEntitiesToCache(entity3List);
    // ... the same two lines repeated for every remaining entity type ...
    List<EntityTypeN> entityNList = DataStore.GetAllEntitiesByType<EntityTypeN>();
    SaveEntitiesToCache(entityNList);
}
My clean-up looks something like this:
// Intended "after" version: loop over the entity types instead of repeating
// the load/save pair for each one.
// NOTE: this is deliberately NOT valid C#. A generic type argument must be a
// type known at compile time, so the runtime Type variable entityType cannot
// be written inside <...>. The Repository class below achieves the same thing
// with reflection (MethodInfo.MakeGenericMethod).
void LoadEntitiesCache()
{
    // Placeholder list; real code would need typeof(EntityType1), etc.
    Type [] allEntityTypes = new Type [] { EntityType1, EntityType2, ..., EntityTypeN };
    foreach(Type entityType in allEntityTypes)
    {
        // Does not compile: entityType is a variable, not a type name.
        List<entityType> entityList = DataStore.GetAllEntitiesByType<entityType>();
        SaveEntitiesToCache(entityList);
    }
}
The logic is there, but the code above won't compile: a generic type argument must be a type known at compile time, so you can't pass a runtime Type variable such as entityType as <entityType>. Instead, reflection must be used. The following code example shows how to use reflection to invoke a generic method (entity/cache update logic is not included):
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Reflection;
using System.Collections.ObjectModel;
using System.Collections.Concurrent;

/// <summary>
/// Contract shared by every business object the repository caches.
/// </summary>
public interface IEntity
{
    /// <summary>
    /// Key under which KeyedEntityCollection&lt;T&gt; indexes this entity.
    /// </summary>
    string Identifier
    {
        get;
    }

    /// <summary>
    /// Category label of the entity (Book and CD return their Category).
    /// </summary>
    string Type
    {
        get;
    }
}

/// <summary>
/// Book entity; identified by its ISBN and typed by its category.
/// </summary>
public class Book : IEntity
{
    /// <summary>ISBN of the book; doubles as <see cref="IEntity.Identifier"/>.</summary>
    public string ISBN { get; set; }

    /// <summary>Category of the book; doubles as <see cref="IEntity.Type"/>.</summary>
    public string Category { get; set; }

    /// <inheritdoc/>
    public string Identifier
    {
        get
        {
            return this.ISBN;
        }
    }

    /// <inheritdoc/>
    public string Type
    {
        get
        {
            return this.Category;
        }
    }
}

/// <summary>
/// CD entity; identified by its name and typed by its category.
/// </summary>
public class CD : IEntity
{
    /// <summary>Name of the CD; doubles as <see cref="IEntity.Identifier"/>.</summary>
    public string Name { get; set; }

    /// <summary>Category of the CD; doubles as <see cref="IEntity.Type"/>.</summary>
    public string Category { get; set; }

    /// <inheritdoc/>
    public string Identifier
    {
        get
        {
            return this.Name;
        }
    }

    /// <inheritdoc/>
    public string Type
    {
        get
        {
            return this.Category;
        }
    }
}

/// <summary>
/// Ordered collection of entities that can also be looked up by
/// <see cref="IEntity.Identifier"/>.
/// </summary>
/// <typeparam name="T">Entity type stored in the collection.</typeparam>
public class KeyedEntityCollection<T> : KeyedCollection<string, T> where T : IEntity
{
    /// <summary>
    /// Creates an empty collection. Must stay public and parameterless:
    /// Repository builds instances via Activator.CreateInstance.
    /// </summary>
    public KeyedEntityCollection()
        : base()
    {
    }

    /// <summary>Derives an item's lookup key from its Identifier.</summary>
    /// <param name="item">Entity being added or looked up.</param>
    /// <returns>The entity's Identifier.</returns>
    protected override string GetKeyForItem(T item)
    {
        string key = item.Identifier;
        return key;
    }
}

/// <summary>
/// Repository that caches one <see cref="KeyedEntityCollection{T}"/> per entity
/// type, loading each from <see cref="DataSource"/> the first time it is needed.
/// </summary>
public class Repository
{
    // Shared by all Repository instances: entity Type -> KeyedEntityCollection<Type>
    // (stored as object because the element type varies per key).
    private static readonly ConcurrentDictionary<Type, object> _entitiesCache = new ConcurrentDictionary<Type, object>();

    /// <summary>
    /// Loads every known entity type that is not cached yet from the back-end
    /// store and stores it in the shared cache.
    /// </summary>
    /// <remarks>
    /// DataSource.GetEntitiesByType&lt;T&gt;() is generic, so it is closed over
    /// each runtime Type via reflection. Exceptions thrown inside it reach the
    /// caller wrapped in a TargetInvocationException (MethodBase.Invoke behavior).
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    /// DataSource no longer exposes a public GetEntitiesByType method.
    /// </exception>
    private void LoadEntitiesCache()
    {
        DataSource dataSource = new DataSource();

        Type[] entityTypes = new Type[] { typeof(Book), typeof(CD) };

        // Resolve the open generic method once; fail loudly (instead of with a
        // NullReferenceException later) if it is ever renamed.
        MethodInfo openGetMethod = typeof(DataSource).GetMethod("GetEntitiesByType");
        if (openGetMethod == null)
        {
            throw new InvalidOperationException("DataSource.GetEntitiesByType<T>() was not found.");
        }

        // Loop-invariant: the open generic collection type.
        Type openCollectionType = typeof(KeyedEntityCollection<>);

        foreach (Type entityType in entityTypes)
        {
            object cachedCollection;
            if (_entitiesCache.TryGetValue(entityType, out cachedCollection) && cachedCollection != null)
            {
                continue; // already loaded
            }

            // Runtime equivalent of dataSource.GetEntitiesByType<entityType>().
            MethodInfo closedGetMethod = openGetMethod.MakeGenericMethod(entityType);
            IEnumerable loadedEntities = closedGetMethod.Invoke(dataSource, null) as IEnumerable;

            // Runtime equivalent of new KeyedEntityCollection<entityType>().
            // KeyedCollection derives from Collection<T>, which implements the
            // non-generic IList, so items can be added without knowing T statically.
            Type closedCollectionType = openCollectionType.MakeGenericType(entityType);
            IList collection = (IList)Activator.CreateInstance(closedCollectionType);

            if (loadedEntities != null)
            {
                foreach (object entity in loadedEntities)
                {
                    collection.Add(entity);
                }
            }

            _entitiesCache.AddOrUpdate(entityType, collection, (type, existing) => collection);
        }
    }

    /// <summary>
    /// Creates a repository, loading any entity types not yet in the shared cache.
    /// </summary>
    public Repository()
    {
        LoadEntitiesCache();
    }

    /// <summary>Returns the cached collection for entity type <typeparamref name="T"/>.</summary>
    /// <returns>The cached collection, or null if <typeparamref name="T"/> is not cached.</returns>
    public KeyedEntityCollection<T> GetEntities<T>() where T : IEntity
    {
        object cachedEntityCollection;
        _entitiesCache.TryGetValue(typeof(T), out cachedEntityCollection);
        return cachedEntityCollection as KeyedEntityCollection<T>;
    }

    /// <summary>Looks up a single cached entity by its Identifier.</summary>
    /// <param name="identifier">Identifier (key) of the entity to find.</param>
    /// <returns>The matching entity, or default(T) if none is cached.</returns>
    public T GetEntityById<T>(string identifier) where T : IEntity
    {
        KeyedEntityCollection<T> entityCollection = GetEntities<T>();
        if (entityCollection != null && entityCollection.Contains(identifier))
        {
            return entityCollection[identifier];
        }
        return default(T);
    }
}

/// <summary>
/// Stand-in for the back-end store (database, XML files, or whatever data store).
/// This sample returns mock data for Book and CD.
/// </summary>
internal class DataSource
{
    // Number of mock entities generated per supported type.
    private const int MockEntityCount = 100;

    /// <summary>
    /// Returns all entities of type <typeparamref name="T"/> from the store.
    /// </summary>
    /// <remarks>
    /// Repository calls this through reflection with no arguments
    /// (MethodBase.Invoke(obj, null)), so the method must stay parameterless.
    /// </remarks>
    /// <returns>
    /// MockEntityCount mock entities (identifiers "1".."100") for Book or CD;
    /// an empty sequence for any other entity type.
    /// </returns>
    public IEnumerable<T> GetEntitiesByType<T>() where T : IEntity
    {
        if (typeof(T) == typeof(Book))
        {
            List<Book> books = Enumerable.Range(1, MockEntityCount)
                .Select(i => new Book() { ISBN = i.ToString(), Category = "Computer" })
                .ToList();
            return books as IEnumerable<T>;
        }

        if (typeof(T) == typeof(CD))
        {
            List<CD> cds = Enumerable.Range(1, MockEntityCount)
                .Select(i => new CD() { Name = i.ToString(), Category = "Music" })
                .ToList();
            return cds as IEnumerable<T>;
        }

        // Unknown entity type: an empty sequence is safer for callers than null.
        return Enumerable.Empty<T>();
    }
}

/// <summary>
/// Console entry point demonstrating collection and by-id lookups on the Repository.
/// </summary>
class Program
{
    static void Main(string[] args)
    {
        Repository repository = new Repository();

        KeyedEntityCollection<Book> allBooks = repository.GetEntities<Book>(); // every cached Book
        Book bookTwo = repository.GetEntityById<Book>("2");                    // Book whose ISBN is "2"
        KeyedEntityCollection<CD> allCds = repository.GetEntities<CD>();       // every cached CD
        CD cdThree = repository.GetEntityById<CD>("3");                        // CD whose Name is "3"

        // Wait for console input so the window stays open.
        Console.Read();
    }
}