NHibernate SchemaValidator reimplemented.

Some time ago a new utility class called SchemaValidator was introduced to NHibernate. It allows you to detect inconsistencies between your mappings and the database schema. A short code sample showing how it works can be found here.

The class has only one useful method, Validate(). It reads the database metadata and compares it with the mappings. It can detect missing tables, missing columns, column types that do not match the mappings, etc. However, it throws a HibernateException as soon as it encounters the first such problem. This is a serious disadvantage (at least for me): I would like to know about all issues related to my mappings at once, and I don't want to be surprised by a sudden NH exception telling me that I have forgotten one column in my .HBM file. That is why I decided to reimplement this class so that, instead of throwing a HibernateException, it returns a list of all the issues it finds.

The following code is based on the original NHibernate.Tool.hbm2ddl.SchemaValidator class. Instead of throwing a HibernateException, it returns a list of strings describing the problems found. The code was tested with SQL Server 2005.

using System;
using System.Collections.Generic;
using System.Data.Common;
using NHibernate;
using NHibernate.Cfg;
using NHibernate.Dialect;
using NHibernate.Dialect.Schema;
using NHibernate.Engine;
using NHibernate.Id;
using NHibernate.Mapping;
using NHibernate.Tool.hbm2ddl;
using NHibernate.Util;

namespace MyApplication.Server.DAO.SchemaValidation
{
    /// <summary>
    /// Compares NHibernate mappings against the live database schema.
    /// Unlike NHibernate.Tool.hbm2ddl.SchemaValidator, it does not throw on the first problem.
    /// Instead it collects every inconsistency it detects as a list of messages:
    /// missing tables, missing columns, mismatched column types, and missing sequences/generator tables.
    /// </summary>
    class SchemaValidator
    {
        private readonly Configuration configuration;
        private readonly IConnectionHelper connectionHelper;
        private readonly Dialect dialect;

        /// <summary>
        /// Creates a validator that connects using the configuration's own properties.
        /// </summary>
        /// <param name="cfg">The configuration whose mappings are validated.</param>
        public SchemaValidator(Configuration cfg) : this(cfg, cfg.Properties) { }

        /// <summary>
        /// Creates a validator that resolves the dialect from <paramref name="connectionProperties"/>.
        /// It opens its own connection from those properties, layered over the dialect defaults.
        /// </summary>
        /// <param name="cfg">The configuration whose mappings are validated.</param>
        /// <param name="connectionProperties">Dialect and connection settings; these override dialect defaults.</param>
        public SchemaValidator(Configuration cfg, IDictionary<string, string> connectionProperties)
        {
            configuration = cfg;
            dialect = Dialect.GetDialect(connectionProperties);

            // Start from the dialect defaults and let the explicit connection properties override them.
            IDictionary<string, string> props = new Dictionary<string, string>(dialect.DefaultProperties);
            foreach (var prop in connectionProperties)
            {
                props[prop.Key] = prop.Value;
            }
            connectionHelper = new ManagedProviderConnectionHelper(props);
        }

        /// <summary>
        /// Creates a validator that uses the dialect and connection provider of already-built settings.
        /// </summary>
        /// <param name="cfg">The configuration whose mappings are validated.</param>
        /// <param name="settings">Supplies the dialect and the connection provider.</param>
        public SchemaValidator(Configuration cfg, Settings settings)
        {
            configuration = cfg;
            dialect = settings.Dialect;
            connectionHelper = new SuppliedConnectionProviderConnectionHelper(settings.ConnectionProvider);
        }

        /// <summary>
        /// Reads the database metadata and validates all mappings against it.
        /// </summary>
        /// <returns>One message per detected problem; an empty list when none were found.</returns>
        /// <remarks>
        /// Failures while connecting or reading metadata still propagate to the caller.
        /// The connection is always released.
        /// A failure during release is rethrown only when it would not hide an earlier exception.
        /// </remarks>
        public IList<string> Validate()
        {
            bool completed = false;
            try
            {
                connectionHelper.Prepare();
                DbConnection connection = connectionHelper.Connection;
                var meta = new DatabaseMetadata(connection, dialect, false);

                IList<string> problems = ValidateSchema(dialect, meta);
                completed = true;
                return problems;
            }
            finally
            {
                try
                {
                    connectionHelper.Release();
                }
                catch (Exception)
                {
                    // Throwing from here while another exception is already propagating would
                    // replace it and lose the real cause, so only surface release failures
                    // on the success path.
                    if (completed)
                    {
                        throw;
                    }
                }
            }
        }

        /// <summary>
        /// Validates every mapped physical table, its columns, and every persistent identifier generator.
        /// HibernateExceptions raised while checking a mapping are recorded as problems instead of aborting the run.
        /// </summary>
        private IList<string> ValidateSchema(
            Dialect dialect, DatabaseMetadata databaseMetadata)
        {
            IList<string> problems = new List<string>();

            string defaultCatalog = GetDefaultCatalog();
            string defaultSchema = GetDefaultSchema();

            IMapping mapping = configuration.BuildMapping();

            // Several class mappings can share one Table instance (e.g. subclasses stored in
            // their root's table); validate each table once to avoid duplicate messages.
            var visitedTables = new HashSet<Table>();

            foreach (PersistentClass pc in configuration.ClassMappings)
            {
                try
                {
                    var table = pc.Table;
                    if (table.IsPhysicalTable && visitedTables.Add(table))
                    {
                        ITableMetadata tableInfo = databaseMetadata.GetTableMetadata(
                            table.Name,
                            table.Schema ?? defaultSchema,
                            table.Catalog ?? defaultCatalog,
                            table.IsQuoted);
                        if (tableInfo == null)
                        {
                            problems.Add(string.Format("Missing table: {0}", table.Name));
                        }
                        else
                        {
                            ValidateColumns(problems, table, dialect, mapping, tableInfo);
                        }
                    }
                }
                catch (HibernateException ex)
                {
                    problems.Add(ex.Message);
                }
            }

            foreach (var generator in IterateGenerators(dialect, problems))
            {
                string key = generator.GeneratorKey();
                if (!databaseMetadata.IsSequence(key) && !databaseMetadata.IsTable(key))
                {
                    problems.Add(string.Format("Missing sequence or table: {0}", key));
                }
            }
            return problems;
        }

        /// <summary>
        /// Collects the distinct persistent identifier generators used by root entities and by identifier collections.
        /// Generators are de-duplicated by their generator key.
        /// </summary>
        /// <param name="dialect">Dialect used to create the generators.</param>
        /// <param name="problems">Receives the message of any HibernateException raised while creating a generator.</param>
        private IEnumerable<IPersistentIdentifierGenerator> IterateGenerators(Dialect dialect, IList<string> problems)
        {
            var generators = new Dictionary<string, IPersistentIdentifierGenerator>();
            string defaultCatalog = GetDefaultCatalog();
            string defaultSchema = GetDefaultSchema();

            foreach (var pc in configuration.ClassMappings)
            {
                // Only root classes own an identifier generator.
                if (pc.IsInherited)
                {
                    continue;
                }

                try
                {
                    var ig =
                        pc.Identifier.CreateIdentifierGenerator(dialect, defaultCatalog, defaultSchema, (RootClass)pc) as
                        IPersistentIdentifierGenerator;

                    if (ig != null)
                    {
                        generators[ig.GeneratorKey()] = ig;
                    }
                }
                catch (HibernateException ex)
                {
                    problems.Add(ex.Message);
                }
            }

            foreach (var collection in configuration.CollectionMappings)
            {
                if (!collection.IsIdentified)
                {
                    continue;
                }

                try
                {
                    var ig =
                        ((IdentifierCollection)collection).Identifier.CreateIdentifierGenerator(dialect, defaultCatalog, defaultSchema,
                                                                                                 null) as IPersistentIdentifierGenerator;

                    if (ig != null)
                    {
                        generators[ig.GeneratorKey()] = ig;
                    }
                }
                catch (HibernateException ex)
                {
                    problems.Add(ex.Message);
                }
            }

            return generators.Values;
        }

        /// <summary>Returns the configured default catalog, or null when none is set.</summary>
        private string GetDefaultCatalog()
        {
            return PropertiesHelper.GetString(NHibernate.Cfg.Environment.DefaultCatalog,
                configuration.Properties, null);
        }

        /// <summary>Returns the configured default schema, or null when none is set.</summary>
        private string GetDefaultSchema()
        {
            return PropertiesHelper.GetString(NHibernate.Cfg.Environment.DefaultSchema,
                configuration.Properties, null);
        }

        /// <summary>
        /// Checks each mapped column of <paramref name="table"/> against the database column metadata.
        /// Reports missing columns and columns whose database type is not a prefix of the mapped SQL type.
        /// </summary>
        /// <param name="problems">Receives one message per detected problem.</param>
        /// <param name="table">The mapped table.</param>
        /// <param name="dialect">Dialect used to render the mapped SQL type.</param>
        /// <param name="mapping">Mapping context used to render the mapped SQL type.</param>
        /// <param name="tableInfo">Metadata of the matching database table.</param>
        private void ValidateColumns(
            IList<string> problems,
            Table table,
            Dialect dialect,
            IMapping mapping,
            ITableMetadata tableInfo)
        {
            // The qualified name is identical for every column, so build it once.
            string qualifiedName = NHibernate.Mapping.Table.Qualify(tableInfo.Catalog, tableInfo.Schema, tableInfo.Name);

            foreach (Column column in table.ColumnIterator)
            {
                IColumnMetadata columnInfo = tableInfo.GetColumnMetadata(column.Name);

                if (columnInfo == null)
                {
                    problems.Add(string.Format("Missing column: {0} in {1}", column.Name, qualifiedName));
                    continue;
                }

                string expectedType = column.GetSqlType(dialect, mapping);

                // Prefix match, presumably because the mapped type may carry a length/precision
                // suffix that the database metadata type name lacks. An ordinal, case-insensitive
                // comparison keeps the check independent of the current culture's casing rules.
                if (!expectedType.StartsWith(columnInfo.TypeName, StringComparison.OrdinalIgnoreCase))
                {
                    problems.Add(string.Format("Wrong column type in {0} for column {1}. Found: {2}, Expected {3}",
                                               qualifiedName,
                                               column.Name, columnInfo.TypeName.ToLowerInvariant(),
                                               expectedType));
                }
            }
        }
    }
}