March 27, 2009

On reflecting extension methods in C#

Extension methods are a great way of adding functionality to existing classes. They almost make C# feel like Ruby or JavaScript, where no class is "closed" and functionality can be added at any time.

For example, say we want the method ToTitleCase to be available on all String objects, so we define an extension method:
namespace MyExtensions
{
    public static class StringExtensions
    {
        public static String ToTitleCase(this String word)
        {
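            // Guard against null or empty input; word[0] would throw otherwise.
            if (String.IsNullOrEmpty(word)) return word;
            return word[0].ToString().ToUpper() + word.Substring(1).ToLower();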
        }
    }
}
Now, wherever we want to use ToTitleCase, we import the namespace MyExtensions with a using directive, and the following becomes valid C# code:
        [TestMethod]
        public void TestCall()
        {
            Assert.AreEqual("Company", "COMPANY".ToTitleCase());
            Assert.AreEqual("Company", "company".ToTitleCase());
            Assert.AreEqual("Company", "Company".ToTitleCase());
        }
However, C# extension methods are simply syntactic sugar. Any extension method calls in source, such as the ones above, are transformed by the compiler into this:
        [TestMethod]
        public void TestCallClassic()
        {
            Assert.AreEqual("Company", StringExtensions.ToTitleCase("COMPANY"));
            Assert.AreEqual("Company", StringExtensions.ToTitleCase("company"));
            Assert.AreEqual("Company", StringExtensions.ToTitleCase("Company"));
        } 
The extension methods for a given type are visible to the compiler, but calls to them are compiled as regular static method calls. Feels like cheating, I tell you.

Now what happens when we try to invoke an extension method via reflection?
        [TestMethod]
        public void TestMethodInfoInTargetClass()
        {
            // We won't find the method in String ...
            Assert.IsNull(typeof(String).GetMethod("ToTitleCase"));
        }

        [TestMethod]
        public void TestMethodInfoInDefiningClass()
        {
            // But we will find it in StringExtensions.
            Assert.IsNotNull(typeof(StringExtensions).GetMethod("ToTitleCase"));
        } 
This should be obvious: we aren't going to find extension methods on the target type. We'll find them on the static class where they are defined.
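
To make that concrete, here is a small sketch that inspects the MethodInfo we just found. It repeats the StringExtensions class from above so that it compiles on its own; the InspectDemo class and its Main method are names I made up for the illustration.

```csharp
using System;
using System.Reflection;

namespace MyExtensions
{
    public static class StringExtensions
    {
        // Same extension method as in the example above.
        public static String ToTitleCase(this String word)
        {
            return word[0].ToString().ToUpper() + word.Substring(1).ToLower();
        }
    }

    // Hypothetical demo class, not part of the original example.
    public static class InspectDemo
    {
        public static void Main()
        {
            MethodInfo method = typeof(StringExtensions).GetMethod("ToTitleCase");

            // It is a plain static method ...
            Console.WriteLine(method.IsStatic); // prints "True"

            // ... whose first parameter is the type being "extended".
            Type firstParameter = method.GetParameters()[0].ParameterType;
            Console.WriteLine(firstParameter == typeof(String)); // prints "True"
        }
    }
}
```

As far as reflection is concerned, the "this" receiver is just parameter number zero of a static method.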

So, how does this affect us?

Say you're browsing through some source code, and you see a call like this - myObject.someMethod(). If you need to call the method someMethod() dynamically, you can't use the type of myObject to reflect it. Instead, you need to know if someMethod() is an extension method, and if it is, you need to reflect it off the class in which it is defined.
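
A minimal sketch of such a dynamic call, assuming we already know that the method lives in StringExtensions. InvokeExtension and InvokeDemo are hypothetical names of my own, and the helper only handles the simple case of a single public, non-overloaded method.

```csharp
using System;
using System.Reflection;

namespace MyExtensions
{
    public static class StringExtensions
    {
        // Same extension method as in the example above.
        public static String ToTitleCase(this String word)
        {
            return word[0].ToString().ToUpper() + word.Substring(1).ToLower();
        }
    }

    // Hypothetical demo class, not part of the original example.
    public static class InvokeDemo
    {
        // Looks the method up on the class that defines it (not on the
        // target's type) and passes the target as the first argument.
        public static object InvokeExtension(
            Type definingType, string methodName, object target, params object[] extraArgs)
        {
            MethodInfo method = definingType.GetMethod(
                methodName, BindingFlags.Static | BindingFlags.Public);
            if (method == null)
                throw new MissingMethodException(definingType.FullName, methodName);

            // The receiver goes first, followed by the remaining arguments.
            object[] args = new object[extraArgs.Length + 1];
            args[0] = target;
            Array.Copy(extraArgs, 0, args, 1, extraArgs.Length);

            // The method is static, so there is no instance to invoke it on.
            return method.Invoke(null, args);
        }

        public static void Main()
        {
            object result = InvokeExtension(typeof(StringExtensions), "ToTitleCase", "COMPANY");
            Console.WriteLine(result); // prints "Company"
        }
    }
}
```

Note the null passed to Invoke: since the method is static, the object we are "calling it on" travels in the argument array instead. If the defining class has several overloads with the same name, GetMethod throws an AmbiguousMatchException, so a real helper would need to pass the parameter types as well.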

That solves the problem when we know which extension method we need to call. If we don't, and we want to know all the extension methods available for a given type, we can use ExtensionAttribute (in the System.Runtime.CompilerServices namespace). The compiler applies this attribute to every extension method, and also to each class and assembly that contains extension methods. Given this, we can implement a function that returns all extension methods defined for a given type.
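        // Requires: System, System.Collections.Generic, System.Linq,
        // System.Reflection, System.Runtime.CompilerServices
        IEnumerable<MethodInfo> GetAllExtensionMethods(Type targetType)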
        {
            return
                from assembly in AppDomain.CurrentDomain.GetAssemblies()
                where assembly.IsDefined(typeof(ExtensionAttribute), false)
                    from type in assembly.GetTypes()
                    where type.IsDefined(typeof(ExtensionAttribute), false)
                    where type.IsSealed && !type.IsGenericType && !type.IsNested
                        from method in type.GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic)
                        // this filters extension methods
                        where method.IsDefined(typeof(ExtensionAttribute), false)
                        where
                            // is it defined on me?
                            targetType == method.GetParameters()[0].ParameterType ||

                            // or on any of my interfaces?
                            targetType.GetInterfaces().Contains(method.GetParameters()[0].ParameterType) ||

                            // or on any of my base types?
                            targetType.IsSubclassOf(method.GetParameters()[0].ParameterType)
                        select method;
        }
The above method was inspired by Jon Skeet's answer on Stack Overflow. It simply improves on it by also matching extension methods declared on the target type's interfaces and base types, and by looking in all assemblies loaded in the current AppDomain.
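
Here is a rough sketch of how the function might be used, wrapped up so that it compiles on its own. ExtensionFinder and its Main method are hypothetical names of mine, and the query is the same one as above with the first parameter's type pulled into a let clause.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace MyExtensions
{
    public static class StringExtensions
    {
        // Same extension method as in the example above.
        public static String ToTitleCase(this String word)
        {
            return word[0].ToString().ToUpper() + word.Substring(1).ToLower();
        }
    }

    // Hypothetical wrapper class, not part of the original example.
    public static class ExtensionFinder
    {
        public static IEnumerable<MethodInfo> GetAllExtensionMethods(Type targetType)
        {
            return
                from assembly in AppDomain.CurrentDomain.GetAssemblies()
                where assembly.IsDefined(typeof(ExtensionAttribute), false)
                from type in assembly.GetTypes()
                where type.IsDefined(typeof(ExtensionAttribute), false)
                where type.IsSealed && !type.IsGenericType && !type.IsNested
                from method in type.GetMethods(
                    BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic)
                // this filters extension methods
                where method.IsDefined(typeof(ExtensionAttribute), false)
                let firstParameter = method.GetParameters()[0].ParameterType
                // defined on the type itself, one of its interfaces, or a base type
                where targetType == firstParameter
                    || targetType.GetInterfaces().Contains(firstParameter)
                    || targetType.IsSubclassOf(firstParameter)
                select method;
        }

        public static void Main()
        {
            // Lists whatever is found for String in the loaded assemblies;
            // the exact output depends on which assemblies are loaded.
            foreach (MethodInfo method in GetAllExtensionMethods(typeof(String)))
                Console.WriteLine(method.DeclaringType.FullName + "." + method.Name);
        }
    }
}
```

Two caveats with this sketch. Assembly.GetTypes can throw a ReflectionTypeLoadException if an assembly has types that fail to load, so production code would probably want to catch that. And because the parameter type is compared for exact equality, extension methods declared on an open generic type such as IEnumerable<T> are not matched.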

Update: Fixed some code up there.