# Dynamic Data Validation System
# Complete end-to-end implementation

# Directory Structure:
# /
# ├── main.py
# ├── config/
# │   ├── bigquery.yaml
# │   ├── snowflake.yaml
# │   └── sqlserver.yaml
# ├── mapping/
# │   └── validation_mapping.csv
# ├── utils/
# │   ├── __init__.py
# │   ├── database_connector.py
# │   ├── query_formatter.py
# │   └── logger.py
# ├── validators/
# │   ├── __init__.py
# │   ├── base_validator.py
# │   └── aggregate_validator.py
# └── requirements.txt

# =============================================================================
# requirements.txt
# =============================================================================
"""
pyyaml>=6.0
pandas>=1.5.0
google-cloud-bigquery>=3.4.0
snowflake-connector-python>=3.0.0
pyodbc>=4.0.39
sqlalchemy>=1.4.0
"""

# =============================================================================
# config/bigquery.yaml
# =============================================================================
bigquery_config = """
connection:
  type: "bigquery"
  project_id: "your-project-id"

sql_templates:
  numeric: >
    SELECT 
      MIN({column}) AS min_value,
      MAX({column}) AS max_value,
      AVG({column}) AS avg_value,
      SUM({column}) AS sum_value,
      COUNT(DISTINCT {column}) AS distinct_count,
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count
    FROM {from_name} {where_clause}

  string: >
    SELECT 
      MIN(LENGTH({column})) AS min_length,
      MAX(LENGTH({column})) AS max_length,
      AVG(LENGTH({column})) AS avg_length,
      COUNT(DISTINCT {column}) AS distinct_count,
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count
    FROM {from_name} {where_clause}

  datetime: >
    SELECT 
      MIN({column}) AS min_date,
      MAX({column}) AS max_date,
      COUNT(DISTINCT {column}) AS distinct_count,
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count
    FROM {from_name} {where_clause}

  null_check: >
    SELECT 
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count,
      (COUNT(*) - COUNT({column})) AS null_count
    FROM {from_name} {where_clause}

  duplicate_check: >
    SELECT 
      COUNT(*) AS total_count,
      COUNT(DISTINCT {column}) AS distinct_count,
      (COUNT(*) - COUNT(DISTINCT {column})) AS duplicate_count
    FROM {from_name} {where_clause}

  total_rows: >
    SELECT COUNT(*) AS row_count
    FROM {from_name} {where_clause}

temp_table_template: >
  CREATE TEMP TABLE {temp_name} AS (
    {custom_query}
  )
"""

# =============================================================================
# config/snowflake.yaml
# =============================================================================
snowflake_config = """
connection:
  type: "snowflake"
  account: "your-account"
  user: "your-user"
  password: "your-password"
  warehouse: "your-warehouse"
  database: "your-database"
  schema: "your-schema"

sql_templates:
  numeric: >
    SELECT 
      MIN({column}) AS min_value,
      MAX({column}) AS max_value,
      AVG({column}) AS avg_value,
      SUM({column}) AS sum_value,
      COUNT(DISTINCT {column}) AS distinct_count,
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count
    FROM {from_name} {where_clause}

  string: >
    SELECT 
      MIN(LENGTH({column})) AS min_length,
      MAX(LENGTH({column})) AS max_length,
      AVG(LENGTH({column})) AS avg_length,
      COUNT(DISTINCT {column}) AS distinct_count,
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count
    FROM {from_name} {where_clause}

  datetime: >
    SELECT 
      MIN({column}) AS min_date,
      MAX({column}) AS max_date,
      COUNT(DISTINCT {column}) AS distinct_count,
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count
    FROM {from_name} {where_clause}

  null_check: >
    SELECT 
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count,
      (COUNT(*) - COUNT({column})) AS null_count
    FROM {from_name} {where_clause}

  duplicate_check: >
    SELECT 
      COUNT(*) AS total_count,
      COUNT(DISTINCT {column}) AS distinct_count,
      (COUNT(*) - COUNT(DISTINCT {column})) AS duplicate_count
    FROM {from_name} {where_clause}

  total_rows: >
    SELECT COUNT(*) AS row_count
    FROM {from_name} {where_clause}

temp_table_template: >
  CREATE TEMPORARY TABLE {temp_name} AS (
    {custom_query}
  )
"""

# =============================================================================
# config/sqlserver.yaml
# =============================================================================
sqlserver_config = """
connection:
  type: "sqlserver"
  server: "your-server"
  database: "your-database"
  username: "your-username"
  password: "your-password"

sql_templates:
  numeric: >
    SELECT 
      MIN({column}) AS min_value,
      MAX({column}) AS max_value,
      AVG(CAST({column} AS FLOAT)) AS avg_value,
      SUM({column}) AS sum_value,
      COUNT(DISTINCT {column}) AS distinct_count,
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count
    FROM {from_name} {where_clause}

  string: >
    SELECT 
      MIN(LEN({column})) AS min_length,
      MAX(LEN({column})) AS max_length,
      AVG(CAST(LEN({column}) AS FLOAT)) AS avg_length,
      COUNT(DISTINCT {column}) AS distinct_count,
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count
    FROM {from_name} {where_clause}

  datetime: >
    SELECT 
      MIN({column}) AS min_date,
      MAX({column}) AS max_date,
      COUNT(DISTINCT {column}) AS distinct_count,
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count
    FROM {from_name} {where_clause}

  null_check: >
    SELECT 
      COUNT(*) AS total_count,
      COUNT({column}) AS non_null_count,
      (COUNT(*) - COUNT({column})) AS null_count
    FROM {from_name} {where_clause}

  duplicate_check: >
    SELECT 
      COUNT(*) AS total_count,
      COUNT(DISTINCT {column}) AS distinct_count,
      (COUNT(*) - COUNT(DISTINCT {column})) AS duplicate_count
    FROM {from_name} {where_clause}

  total_rows: >
    SELECT COUNT(*) AS row_count
    FROM {from_name} {where_clause}

temp_table_template: >
  SELECT * INTO #{temp_name}
  FROM (
    {custom_query}
  ) AS temp_data
"""

# =============================================================================
# utils/__init__.py
# =============================================================================
# Empty file to make it a package

# =============================================================================
# utils/logger.py
# =============================================================================
import logging
import os
import sys

class ValidationLogger:
    """Centralized logging utility for the validation system."""

    def __init__(self, log_level=logging.INFO, log_file=None):
        """Initialize logger with specified level and optional file output."""
        self.logger = logging.getLogger('DataValidation')
        self.logger.setLevel(log_level)

        # Clear existing handlers
        self.logger.handlers.clear()

        # Create formatter
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

        # Console handler
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(formatter)
        self.logger.addHandler(console_handler)

        # File handler if specified
        if log_file:
            os.makedirs(os.path.dirname(log_file), exist_ok=True)
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)

    def info(self, message):
        """Log info message."""
        self.logger.info(message)

    def warning(self, message):
        """Log warning message."""
        self.logger.warning(message)

    def error(self, message):
        """Log error message."""
        self.logger.error(message)

    def debug(self, message):
        """Log debug message."""
        self.logger.debug(message)
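
# Illustrative self-check; runs only when this module is executed directly.
if __name__ == "__main__":
    demo_logger = ValidationLogger(log_level=logging.DEBUG)
    demo_logger.info("logger initialized")
    demo_logger.debug("debug output is enabled at DEBUG level")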

# =============================================================================
# utils/query_formatter.py
# =============================================================================
import math
import re
from typing import Any

class SafeQueryFormatter:
    """Safe SQL query formatter that only substitutes existing placeholders."""

    @staticmethod
    def safe_format(template: str, **kwargs) -> str:
        """
        Safely format SQL template by only substituting existing placeholders.

        Args:
            template: SQL template string with {placeholder} syntax
            **kwargs: Values to substitute

        Returns:
            Formatted SQL string

        Raises:
            ValueError: If required placeholders are missing
        """
        try:
            # Find all placeholders in the template
            placeholders = set(re.findall(r'\{(\w+)\}', template))

            # Create safe substitution dict with only existing placeholders
            safe_substitutions = {}
            missing_placeholders = []

            for placeholder in placeholders:
                if placeholder in kwargs:
                    safe_substitutions[placeholder] = kwargs[placeholder]
                else:
                    # Handle optional placeholders
                    if placeholder == 'where_clause':
                        safe_substitutions[placeholder] = ''
                    else:
                        missing_placeholders.append(placeholder)

            if missing_placeholders:
                raise ValueError(f"Missing required placeholders: {missing_placeholders}")

            return template.format(**safe_substitutions)

        except Exception as e:
            raise ValueError(f"Query formatting failed: {str(e)}")

    @staticmethod
    def prepare_where_clause(where_clause: Any) -> str:
        """
        Prepare WHERE clause, handling various input types safely.

        Args:
            where_clause: WHERE clause condition (can be None, NaN, or string)

        Returns:
            Properly formatted WHERE clause or empty string
        """
        if where_clause is None:
            return ''

        # Handle pandas NaN values (NaN is a float instance)
        if isinstance(where_clause, float) and math.isnan(where_clause):
            return ''

        # Convert to string and strip
        where_str = str(where_clause).strip()

        if not where_str or where_str.lower() == 'nan':
            return ''

        # Ensure WHERE clause starts with WHERE keyword
        if not where_str.upper().startswith('WHERE'):
            where_str = f"WHERE {where_str}"

        return where_str
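
# Illustrative self-check (the table name "sales.orders" is a made-up example); runs only
# when this module is executed directly, e.g. `python -m utils.query_formatter`.
if __name__ == "__main__":
    demo_template = "SELECT COUNT(*) AS row_count FROM {from_name} {where_clause}"
    # where_clause is optional, so it is substituted with an empty string here.
    print(SafeQueryFormatter.safe_format(demo_template, from_name="sales.orders"))
    # prepare_where_clause adds the WHERE keyword when it is missing:
    print(SafeQueryFormatter.prepare_where_clause("active = 1"))  # -> "WHERE active = 1"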

# =============================================================================
# utils/database_connector.py
# =============================================================================
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple

import yaml

class DatabaseConnector(ABC):
    """Abstract base class for database connections."""

    def __init__(self, config_path: str):
        """Initialize connector with configuration."""
        self.config = self._load_config(config_path)
        self.connection = None

    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """Load database configuration from YAML file."""
        try:
            with open(config_path, 'r') as file:
                return yaml.safe_load(file)
        except Exception as e:
            raise ValueError(f"Failed to load config from {config_path}: {str(e)}")

    @abstractmethod
    def connect(self):
        """Establish database connection."""
        pass

    @abstractmethod
    def execute_query(self, query: str) -> List[Tuple]:
        """Execute query and return results."""
        pass

    @abstractmethod
    def close(self):
        """Close database connection."""
        pass

    def get_sql_template(self, template_name: str) -> str:
        """Get SQL template by name."""
        templates = self.config.get('sql_templates', {})
        if template_name not in templates:
            raise ValueError(f"Template '{template_name}' not found in configuration")
        return templates[template_name]

    def get_temp_table_template(self) -> str:
        """Get temporary table creation template."""
        return self.config.get('temp_table_template', '')

class BigQueryConnector(DatabaseConnector):
    """BigQuery database connector."""

    def connect(self):
        """Establish BigQuery connection."""
        try:
            from google.cloud import bigquery
            self.connection = bigquery.Client(
                project=self.config['connection']['project_id']
            )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to BigQuery: {str(e)}")

    def execute_query(self, query: str) -> List[Tuple]:
        """Execute BigQuery query and return results."""
        try:
            query_job = self.connection.query(query)
            results = query_job.result()  # BigQuery requires .result() call
            return [tuple(row.values()) for row in results]
        except Exception as e:
            raise RuntimeError(f"Query execution failed: {str(e)}")

    def close(self):
        """Close BigQuery connection."""
        self.connection = None

class SnowflakeConnector(DatabaseConnector):
    """Snowflake database connector."""

    def connect(self):
        """Establish Snowflake connection."""
        try:
            import snowflake.connector
            conn_params = self.config['connection']
            self.connection = snowflake.connector.connect(
                account=conn_params['account'],
                user=conn_params['user'],
                password=conn_params['password'],
                warehouse=conn_params['warehouse'],
                database=conn_params['database'],
                schema=conn_params['schema']
            )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to Snowflake: {str(e)}")

    def execute_query(self, query: str) -> List[Tuple]:
        """Execute Snowflake query and return results."""
        try:
            cursor = self.connection.cursor()
            cursor.execute(query)
            results = cursor.fetchall()
            cursor.close()
            return results
        except Exception as e:
            raise RuntimeError(f"Query execution failed: {str(e)}")

    def close(self):
        """Close Snowflake connection."""
        if self.connection:
            self.connection.close()

class SQLServerConnector(DatabaseConnector):
    """SQL Server database connector."""

    def connect(self):
        """Establish SQL Server connection."""
        try:
            import pyodbc
            conn_params = self.config['connection']
            connection_string = (
                f"DRIVER={{ODBC Driver 17 for SQL Server}};"
                f"SERVER={conn_params['server']};"
                f"DATABASE={conn_params['database']};"
                f"UID={conn_params['username']};"
                f"PWD={conn_params['password']}"
            )
            self.connection = pyodbc.connect(connection_string)
        except Exception as e:
            raise ConnectionError(f"Failed to connect to SQL Server: {str(e)}")

    def execute_query(self, query: str) -> List[Tuple]:
        """Execute SQL Server query and return results."""
        try:
            cursor = self.connection.cursor()
            cursor.execute(query)
            results = cursor.fetchall()
            cursor.close()
            return results
        except Exception as e:
            raise RuntimeError(f"Query execution failed: {str(e)}")

    def close(self):
        """Close SQL Server connection."""
        if self.connection:
            self.connection.close()

class DatabaseConnectorFactory:
    """Factory for creating database connectors."""

    @staticmethod
    def create_connector(db_type: str, config_path: str) -> DatabaseConnector:
        """Create appropriate database connector based on type."""
        connectors = {
            'bigquery': BigQueryConnector,
            'snowflake': SnowflakeConnector,
            'sqlserver': SQLServerConnector
        }

        if db_type.lower() not in connectors:
            raise ValueError(f"Unsupported database type: {db_type}")

        return connectors[db_type.lower()](config_path)
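
# Illustrative usage (assumes config/bigquery.yaml exists and holds valid credentials);
# runs only when this module is executed directly.
if __name__ == "__main__":
    connector = DatabaseConnectorFactory.create_connector("bigquery", "config/bigquery.yaml")
    connector.connect()
    try:
        print(connector.execute_query("SELECT 1 AS ok"))
    finally:
        connector.close()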

# =============================================================================
# validators/__init__.py
# =============================================================================
# Empty file to make it a package

# =============================================================================
# validators/base_validator.py
# =============================================================================
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple

import pandas as pd

from utils.logger import ValidationLogger

class BaseValidator(ABC):
    """Abstract base class for all validators."""

    def __init__(self, logger: ValidationLogger):
        """Initialize validator with logger."""
        self.logger = logger

    @abstractmethod
    def validate(self, mapping_data: pd.DataFrame) -> List[Dict[str, Any]]:
        """
        Perform validation based on mapping data.

        Args:
            mapping_data: DataFrame containing validation mapping

        Returns:
            List of validation results
        """
        pass

    def _is_custom_query(self, object_name: str) -> bool:
        """Check if object name is a custom query or SQL file."""
        if not isinstance(object_name, str):
            return False

        object_name = object_name.strip()
        return (object_name.upper().startswith('SELECT') or 
                object_name.lower().endswith('.sql'))

    def _load_sql_file(self, file_path: str) -> str:
        """Load SQL query from file."""
        try:
            with open(file_path, 'r') as file:
                return file.read().strip()
        except Exception as e:
            raise FileNotFoundError(f"Failed to load SQL file {file_path}: {str(e)}")

    def _get_query_or_table(self, object_name: str) -> Tuple[str, bool]:
        """
        Determine if object_name is a table or custom query.

        Returns:
            Tuple of (query_or_table_name, is_custom_query)
        """
        if self._is_custom_query(object_name):
            if object_name.lower().endswith('.sql'):
                query = self._load_sql_file(object_name)
            else:
                query = object_name
            return query, True
        else:
            return object_name, False

# =============================================================================
# validators/aggregate_validator.py
# =============================================================================
import math
from typing import Any, Dict, List, Union

import pandas as pd
from validators.base_validator import BaseValidator
from utils.database_connector import DatabaseConnectorFactory
from utils.query_formatter import SafeQueryFormatter
from utils.logger import ValidationLogger

class AggregateValidator(BaseValidator):
    """Validator for aggregate-based data validation."""

    def __init__(self, logger: ValidationLogger):
        """Initialize aggregate validator."""
        super().__init__(logger)
        self.formatter = SafeQueryFormatter()
        self.temp_objects = []  # Track temporary objects for cleanup

    def validate(self, mapping_data: pd.DataFrame) -> List[Dict[str, Any]]:
        """
        Perform aggregate validation based on mapping data.

        Args:
            mapping_data: DataFrame containing validation mapping

        Returns:
            List of validation results
        """
        results = []

        for _, row in mapping_data.iterrows():
            try:
                self.logger.info(f"Processing mapping_id: {row['mapping_id']}")

                # Validate row data
                validation_result = self._validate_mapping_row(row)
                results.append(validation_result)

            except Exception as e:
                self.logger.error(f"Failed to process mapping_id {row['mapping_id']}: {str(e)}")
                results.append(self._create_error_result(row, str(e)))

        return results

    def _validate_mapping_row(self, row: pd.Series) -> Dict[str, Any]:
        """Validate a single mapping row."""
        try:
            # Extract and validate row data
            mapping_id = row['mapping_id']
            source_name = row['source_name']
            target_name = row['target_name']
            validation_mode = row.get('validation_mode', 'column')
            where_clause = self.formatter.prepare_where_clause(row.get('where_clause'))
            exclude_columns = self._parse_exclude_columns(row.get('exclude_columns'))

            # Determine database type from source name (simple heuristic)
            db_type = self._determine_db_type(source_name)
            config_path = f"config/{db_type}.yaml"

            # Create database connector
            connector = DatabaseConnectorFactory.create_connector(db_type, config_path)
            connector.connect()

            try:
                # Process source and target
                source_info = self._process_data_source(connector, source_name, where_clause, 'source')
                target_info = self._process_data_source(connector, target_name, where_clause, 'target')

                # Get columns for validation
                source_columns = self._get_columns_for_validation(
                    connector, source_info['from_name'], exclude_columns
                )
                target_columns = self._get_columns_for_validation(
                    connector, target_info['from_name'], exclude_columns
                )

                # Perform validation based on mode
                if validation_mode.lower() == 'column':
                    return self._validate_columns(
                        connector, source_info, target_info, 
                        source_columns, target_columns, mapping_id, where_clause
                    )
                else:
                    return self._validate_row_count(
                        connector, source_info, target_info, mapping_id, where_clause
                    )

            finally:
                self._cleanup_temp_objects(connector)
                connector.close()

        except Exception as e:
            self.logger.error(f"Validation failed for mapping_id {row['mapping_id']}: {str(e)}")
            return self._create_error_result(row, str(e))

    def _determine_db_type(self, source_name: str) -> str:
        """Determine database type from source name (simple heuristic)."""
        # This is a simple heuristic - in production, you might want a more sophisticated approach
        if 'bigquery' in source_name.lower() or source_name.count('.') >= 2:
            return 'bigquery'
        elif 'snowflake' in source_name.lower():
            return 'snowflake'
        else:
            return 'sqlserver'

    def _process_data_source(self, connector, object_name: str, where_clause: str, source_type: str) -> Dict[str, Any]:
        """Process data source (table or custom query)."""
        query_or_table, is_custom = self._get_query_or_table(object_name)

        if is_custom:
            # Create temporary object for custom query
            temp_name = f"temp_{source_type}_{hash(query_or_table) % 10000}"
            temp_template = connector.get_temp_table_template()

            temp_query = self.formatter.safe_format(
                temp_template,
                temp_name=temp_name,
                custom_query=query_or_table
            )

            self.logger.info(f"Creating temporary object: {temp_name}")
            connector.execute_query(temp_query)
            self.temp_objects.append(temp_name)

            return {
                'from_name': temp_name,
                'is_custom': True,
                'original_query': query_or_table
            }
        else:
            return {
                'from_name': query_or_table,
                'is_custom': False,
                'original_query': None
            }

    def _get_columns_for_validation(self, connector, from_name: str, exclude_columns: List[str]) -> List[Dict[str, str]]:
        """Get columns from table/view for validation."""
        try:
            # Get column information
            info_query = f"""
            SELECT column_name, data_type 
            FROM information_schema.columns 
            WHERE table_name = '{from_name.split('.')[-1]}'
            """

            # Try to execute, if fails, use a simpler approach
            try:
                results = connector.execute_query(info_query)
                columns = [{'name': row[0], 'type': self._normalize_data_type(row[1])} 
                          for row in results 
                          if row[0].lower() not in [col.lower() for col in exclude_columns]]
            except Exception:
                # Fallback: probe the object with a zero-row query (portable across engines)
                limit_query = f"SELECT * FROM {from_name} WHERE 1 = 0"
                connector.execute_query(limit_query)
                # This is a simplified approach - in production, you'd want better column introspection
                columns = [
                    {'name': 'id', 'type': 'numeric'},
                    {'name': 'name', 'type': 'string'},
                    {'name': 'created_date', 'type': 'datetime'}
                ]

            return columns

        except Exception as e:
            self.logger.warning(f"Failed to get columns for {from_name}: {str(e)}")
            return []

    def _normalize_data_type(self, data_type: str) -> str:
        """Normalize database-specific data types to standard categories."""
        data_type = data_type.lower()

        if any(t in data_type for t in ['int', 'decimal', 'numeric', 'float', 'double', 'real']):
            return 'numeric'
        elif any(t in data_type for t in ['date', 'time', 'timestamp']):
            return 'datetime'
        else:
            return 'string'

    def _validate_columns(self, connector, source_info: Dict, target_info: Dict, 
                         source_columns: List[Dict], target_columns: List[Dict],
                         mapping_id: int, where_clause: str) -> Dict[str, Any]:
        """Validate columns between source and target."""
        try:
            # Find common columns
            source_col_names = {col['name'].lower() for col in source_columns}
            target_col_names = {col['name'].lower() for col in target_columns}
            common_columns = source_col_names.intersection(target_col_names)

            if not common_columns:
                return {
                    'mapping_id': mapping_id,
                    'validation_status': 'FAILED',
                    'remarks': 'No common columns found between source and target',
                    'source_total_rows': 0,
                    'target_total_rows': 0,
                    'details': []
                }

            # Get total row counts
            source_rows = self._get_total_rows(connector, source_info['from_name'], where_clause)
            target_rows = self._get_total_rows(connector, target_info['from_name'], where_clause)

            validation_details = []
            overall_status = 'PASSED'

            for col_name in common_columns:
                try:
                    # Find column type
                    col_type = next((col['type'] for col in source_columns 
                                   if col['name'].lower() == col_name), 'string')

                    # Get aggregates for source and target
                    source_agg = self._get_column_aggregates(
                        connector, source_info['from_name'], col_name, col_type, where_clause
                    )
                    target_agg = self._get_column_aggregates(
                        connector, target_info['from_name'], col_name, col_type, where_clause
                    )

                    # Normalize and compare aggregates
                    source_norm = self._normalize_aggregates(source_agg, col_type)
                    target_norm = self._normalize_aggregates(target_agg, col_type)

                    # Compare aggregates
                    comparison_result = self._compare_aggregates(source_norm, target_norm, col_type)

                    validation_details.append({
                        'column': col_name,
                        'data_type': col_type,
                        'validation_status': 'PASSED' if comparison_result['match'] else 'FAILED',
                        'remarks': comparison_result['remarks'],
                        'source_aggregates': source_agg,
                        'target_aggregates': target_agg,
                        'source_normalized': source_norm,
                        'target_normalized': target_norm
                    })

                    if not comparison_result['match']:
                        overall_status = 'FAILED'

                except Exception as e:
                    self.logger.error(f"Column validation failed for {col_name}: {str(e)}")
                    validation_details.append({
                        'column': col_name,
                        'data_type': 'unknown',
                        'validation_status': 'ERROR',
                        'remarks': f'Validation error: {str(e)}',
                        'source_aggregates': {},
                        'target_aggregates': {},
                        'source_normalized': {},
                        'target_normalized': {}
                    })
                    overall_status = 'FAILED'

            return {
                'mapping_id': mapping_id,
                'validation_status': overall_status,
                'remarks': f'Validated {len(validation_details)} columns',
                'source_total_rows': source_rows,
                'target_total_rows': target_rows,
                'details': validation_details
            }

        except Exception as e:
            self.logger.error(f"Column validation failed: {str(e)}")
            return self._create_error_result({'mapping_id': mapping_id}, str(e))

    def _validate_row_count(self, connector, source_info: Dict, target_info: Dict,
                           mapping_id: int, where_clause: str) -> Dict[str, Any]:
        """Validate row counts between source and target."""
        try:
            source_rows = self._get_total_rows(connector, source_info['from_name'], where_clause)
            target_rows = self._get_total_rows(connector, target_info['from_name'], where_clause)

            match = source_rows == target_rows
            status = 'PASSED' if match else 'FAILED'
            remarks = f'Row count match: {match}. Source: {source_rows}, Target: {target_rows}'

            return {
                'mapping_id': mapping_id,
                'validation_status': status,
                'remarks': remarks,
                'source_total_rows': source_rows,
                'target_total_rows': target_rows,
                'details': [{
                    'column': 'ROW_COUNT',
                    'data_type': 'numeric',
                    'validation_status': status,
                    'remarks': remarks,
                    'source_aggregates': {'row_count': source_rows},
                    'target_aggregates': {'row_count': target_rows},
                    'source_normalized': {'row_count': source_rows},
                    'target_normalized': {'row_count': target_rows}
                }]
            }

        except Exception as e:
            self.logger.error(f"Row count validation failed: {str(e)}")
            return self._create_error_result({'mapping_id': mapping_id}, str(e))

    def _get_total_rows(self, connector, from_name: str, where_clause: str) -> int:
        """Get total row count for a table/query."""
        try:
            template = connector.get_sql_template('total_rows')
            query = self.formatter.safe_format(
                template,
                from_name=from_name,
                where_clause=where_clause
            )

            results = connector.execute_query(query)
            return results[0][0] if results else 0

        except Exception as e:
            self.logger.error(f"Failed to get total rows for {from_name}: {str(e)}")
            return 0

    def _get_column_aggregates(self, connector, from_name: str, column: str, 
                              data_type: str, where_clause: str) -> Dict[str, Any]:
        """Get aggregate statistics for a column."""
        try:
            template = connector.get_sql_template(data_type)
            query = self.formatter.safe_format(
                template,
                column=column,
                from_name=from_name,
                where_clause=where_clause
            )

            results = connector.execute_query(query)

            if not results:
                return {}

            # Convert result to dictionary
            result_row = results[0]

            # Define expected columns based on data type
            if data_type == 'numeric':
                keys = ['min_value', 'max_value', 'avg_value', 'sum_value', 
                       'distinct_count', 'total_count', 'non_null_count']
            elif data_type == 'string':
                keys = ['min_length', 'max_length', 'avg_length', 
                       'distinct_count', 'total_count', 'non_null_count']
            elif data_type == 'datetime':
                keys = ['min_date', 'max_date', 'distinct_count', 
                       'total_count', 'non_null_count']
            else:
                keys = [f'col_{i}' for i in range(len(result_row))]

            # Create dictionary from results
            aggregates = {}
            for i, key in enumerate(keys[:len(result_row)]):
                aggregates[key] = result_row[i]

            return aggregates

        except Exception as e:
            self.logger.error(f"Failed to get aggregates for column {column}: {str(e)}")
            return {}

    def _normalize_aggregates(self, aggregates: Dict[str, Any], data_type: str) -> Dict[str, Any]:
        """Normalize aggregate values for comparison."""
        if not aggregates:
            return {}

        normalized = {}

        try:
            for key, value in aggregates.items():
                if value is None:
                    normalized[key] = None
                elif data_type == 'numeric' and 'value' in key:
                    # Round numeric values to 2 decimal places
                    try:
                        normalized[key] = round(float(value), 2)
                    except (ValueError, TypeError):
                        normalized[key] = value
                elif data_type == 'datetime' and ('date' in key or 'time' in key):
                    # Normalize datetime to ISO format string
                    try:
                        if hasattr(value, 'isoformat'):
                            normalized[key] = value.isoformat()
                        else:
                            normalized[key] = str(value)
                    except Exception:
                        normalized[key] = str(value)
                elif data_type == 'string' and isinstance(value, str):
                    # Normalize string values (lowercase, strip)
                    normalized[key] = value.lower().strip()
                else:
                    # Keep other values as-is but ensure they're comparable
                    try:
                        if isinstance(value, float) and math.isnan(value):
                            normalized[key] = None
                        else:
                            normalized[key] = value
                    except Exception:
                        normalized[key] = value

            return normalized

        except Exception as e:
            self.logger.error(f"Normalization failed for {data_type}: {str(e)}")
            return aggregates

    def _compare_aggregates(self, source: Dict[str, Any], target: Dict[str, Any], 
                           data_type: str) -> Dict[str, Any]:
        """Compare normalized aggregates between source and target."""
        try:
            if not source and not target:
                return {'match': True, 'remarks': 'Both aggregates empty'}

            if not source or not target:
                return {'match': False, 'remarks': 'One aggregate is empty'}

            mismatches = []
            tolerance = 0.01  # 1% tolerance for numeric comparisons

            # Get all keys from both dictionaries
            all_keys = set(source.keys()) | set(target.keys())

            for key in all_keys:
                source_val = source.get(key)
                target_val = target.get(key)

                if source_val is None and target_val is None:
                    continue

                if source_val is None or target_val is None:
                    mismatches.append(f"{key}: {source_val} vs {target_val}")
                    continue

                # Compare based on data type and key
                if (data_type == 'numeric' and 
                    isinstance(source_val, (int, float)) and 
                    isinstance(target_val, (int, float))):

                    if abs(source_val - target_val) > max(abs(source_val), abs(target_val)) * tolerance:
                        mismatches.append(f"{key}: {source_val} vs {target_val}")
                else:
                    if source_val != target_val:
                        mismatches.append(f"{key}: {source_val} vs {target_val}")

            match = len(mismatches) == 0
            remarks = 'Aggregates match' if match else f'Mismatches: {"; ".join(mismatches[:5])}'

            return {'match': match, 'remarks': remarks}

        except Exception as e:
            self.logger.error(f"Aggregate comparison failed: {str(e)}")
            return {'match': False, 'remarks': f'Comparison error: {str(e)}'}

    def _parse_exclude_columns(self, exclude_columns: Any) -> List[str]:
        """Parse exclude_columns field into list of column names."""
        if pd.isna(exclude_columns) or exclude_columns is None:
            return []

        if isinstance(exclude_columns, str):
            # Split by comma and clean up
            return [col.strip() for col in exclude_columns.split(',') if col.strip()]

        return []

    def _create_error_result(self, row: Union[pd.Series, Dict], error_message: str) -> Dict[str, Any]:
        """Create error result for failed validation."""
        mapping_id = row.get('mapping_id', 'unknown') if hasattr(row, 'get') else row['mapping_id']

        return {
            'mapping_id': mapping_id,
            'validation_status': 'ERROR',
            'remarks': f'Validation error: {error_message}',
            'source_total_rows': 0,
            'target_total_rows': 0,
            'details': [{
                'column': 'ERROR',
                'data_type': 'unknown',
                'validation_status': 'ERROR',
                'remarks': error_message,
                'source_aggregates': {},
                'target_aggregates': {},
                'source_normalized': {},
                'target_normalized': {}
            }]
        }

    def _cleanup_temp_objects(self, connector):
        """Clean up temporary objects created during validation."""
        for temp_name in self.temp_objects:
            try:
                cleanup_query = f"DROP TABLE IF EXISTS {temp_name}"
                connector.execute_query(cleanup_query)
                self.logger.info(f"Cleaned up temporary object: {temp_name}")
            except Exception as e:
                self.logger.warning(f"Failed to cleanup {temp_name}: {str(e)}")

        self.temp_objects.clear()

# =============================================================================
# main.py
# =============================================================================
import pandas as pd
import logging
import os
import sys
from datetime import datetime
from typing import List, Dict, Any
import argparse

from utils.logger import ValidationLogger
from validators.aggregate_validator import AggregateValidator

class DataValidationSystem:
    """Main orchestrator for the data validation system."""

    def __init__(self, log_level='INFO', log_file=None):
        """Initialize the validation system."""
        self.logger = ValidationLogger(
            log_level=getattr(logging, log_level.upper()),
            log_file=log_file
        )
        self.validator = AggregateValidator(self.logger)

    def load_mapping_file(self, mapping_file_path: str) -> pd.DataFrame:
        """Load and validate mapping CSV file."""
        try:
            self.logger.info(f"Loading mapping file: {mapping_file_path}")

            if not os.path.exists(mapping_file_path):
                raise FileNotFoundError(f"Mapping file not found: {mapping_file_path}")

            # Load CSV file
            mapping_df = pd.read_csv(mapping_file_path)

            # Validate required columns
            required_columns = [
                'mapping_id', 'source_name', 'source_object_type',
                'target_name', 'target_object_type', 'validation_mode'
            ]

            missing_columns = [col for col in required_columns if col not in mapping_df.columns]
            if missing_columns:
                raise ValueError(f"Missing required columns: {missing_columns}")

            # Add optional columns if they don't exist
            optional_columns = ['where_clause', 'exclude_columns']
            for col in optional_columns:
                if col not in mapping_df.columns:
                    mapping_df[col] = None

            self.logger.info(f"Loaded {len(mapping_df)} mapping records")
            return mapping_df

        except Exception as e:
            self.logger.error(f"Failed to load mapping file: {str(e)}")
            raise

    def validate_configuration(self) -> bool:
        """Validate that required configuration files exist."""
        try:
            config_dir = 'config'
            required_configs = ['bigquery.yaml', 'snowflake.yaml', 'sqlserver.yaml']

            if not os.path.exists(config_dir):
                self.logger.error(f"Configuration directory not found: {config_dir}")
                return False

            missing_configs = []
            for config_file in required_configs:
                config_path = os.path.join(config_dir, config_file)
                if not os.path.exists(config_path):
                    missing_configs.append(config_file)

            if missing_configs:
                self.logger.error(f"Missing configuration files: {missing_configs}")
                return False

            self.logger.info("Configuration validation passed")
            return True

        except Exception as e:
            self.logger.error(f"Configuration validation failed: {str(e)}")
            return False

    def run_validation(self, mapping_file_path: str, output_file_path: str = None) -> str:
        """
        Run the complete validation process.

        Args:
            mapping_file_path: Path to the mapping CSV file
            output_file_path: Path for output CSV report (optional)

        Returns:
            Path to the generated report file
        """
        try:
            self.logger.info("Starting data validation process")

            # Validate configuration
            if not self.validate_configuration():
                raise RuntimeError("Configuration validation failed")

            # Load mapping file
            mapping_df = self.load_mapping_file(mapping_file_path)

            # Run validation
            self.logger.info("Running validation...")
            validation_results = self.validator.validate(mapping_df)

            # Generate report
            report_path = self.generate_report(validation_results, output_file_path)

            # Log summary
            self._log_validation_summary(validation_results)

            self.logger.info(f"Validation completed. Report saved to: {report_path}")
            return report_path

        except Exception as e:
            self.logger.error(f"Validation process failed: {str(e)}")
            raise

    def generate_report(self, validation_results: List[Dict[str, Any]], 
                       output_file_path: str = None) -> str:
        """Generate CSV validation report."""
        try:
            # Flatten results for CSV output
            report_rows = []

            for result in validation_results:
                base_info = {
                    'mapping_id': result['mapping_id'],
                    'validation_status': result['validation_status'],
                    'remarks': result['remarks'],
                    'source_total_rows': result['source_total_rows'],
                    'target_total_rows': result['target_total_rows']
                }

                if result.get('details'):
                    for detail in result['details']:
                        row = base_info.copy()
                        row.update({
                            'column': detail['column'],
                            'data_type': detail['data_type'],
                            'column_validation_status': detail['validation_status'],
                            'column_remarks': detail['remarks'],
                            'source_aggregates': str(detail['source_aggregates']),
                            'target_aggregates': str(detail['target_aggregates']),
                            'source_normalized': str(detail['source_normalized']),
                            'target_normalized': str(detail['target_normalized'])
                        })
                        report_rows.append(row)
                else:
                    # Add row even if no details
                    row = base_info.copy()
                    row.update({
                        'column': 'N/A',
                        'data_type': 'N/A',
                        'column_validation_status': 'N/A',
                        'column_remarks': 'N/A',
                        'source_aggregates': '{}',
                        'target_aggregates': '{}',
                        'source_normalized': '{}',
                        'target_normalized': '{}'
                    })
                    report_rows.append(row)

            # Create DataFrame and save to CSV
            report_df = pd.DataFrame(report_rows)

            # Generate output file path if not provided
            if not output_file_path:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                output_file_path = f"validation_report_{timestamp}.csv"

            # Ensure output directory exists
            output_dir = os.path.dirname(output_file_path)
            if output_dir and not os.path.exists(output_dir):
                os.makedirs(output_dir)

            # Save report
            report_df.to_csv(output_file_path, index=False)
            self.logger.info(f"Report saved with {len(report_rows)} rows")

            return output_file_path

        except Exception as e:
            self.logger.error(f"Report generation failed: {str(e)}")
            raise

    def _log_validation_summary(self, validation_results: List[Dict[str, Any]]):
        """Log summary of validation results."""
        try:
            total_validations = len(validation_results)
            passed = sum(1 for r in validation_results if r['validation_status'] == 'PASSED')
            failed = sum(1 for r in validation_results if r['validation_status'] == 'FAILED')
            errors = sum(1 for r in validation_results if r['validation_status'] == 'ERROR')

            self.logger.info("=== VALIDATION SUMMARY ===")
            self.logger.info(f"Total validations: {total_validations}")
            self.logger.info(f"Passed: {passed}")
            self.logger.info(f"Failed: {failed}")
            self.logger.info(f"Errors: {errors}")
            self.logger.info(f"Success rate: {(passed/total_validations)*100:.1f}%")

        except Exception as e:
            self.logger.error(f"Failed to generate summary: {str(e)}")

def main():
    """Main entry point for the validation system."""
    parser = argparse.ArgumentParser(description='Dynamic Data Validation System')
    parser.add_argument('mapping_file', help='Path to the mapping CSV file')
    parser.add_argument('-o', '--output', help='Output CSV report file path')
    parser.add_argument('-l', '--log-level', default='INFO', 
                       choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                       help='Logging level')
    parser.add_argument('--log-file', help='Log file path (optional)')

    args = parser.parse_args()

    try:
        # Initialize validation system
        validation_system = DataValidationSystem(
            log_level=args.log_level,
            log_file=args.log_file
        )

        # Run validation
        report_path = validation_system.run_validation(
            mapping_file_path=args.mapping_file,
            output_file_path=args.output
        )

        print(f"Validation completed successfully!")
        print(f"Report saved to: {report_path}")

    except Exception as e:
        print(f"Validation failed: {str(e)}")
        sys.exit(1)

if __name__ == "__main__":
    main()

# =============================================================================
# Sample Usage and Testing
# =============================================================================

# Sample mapping CSV content (save as mapping/validation_mapping.csv):
sample_mapping_csv = """mapping_id,source_name,source_object_type,target_name,target_object_type,validation_mode,where_clause,exclude_columns
1,project.dataset.source_table,table,project.dataset.target_table,table,column,WHERE active = 1,id
2,"SELECT * FROM project.dataset.custom_source",query,project.dataset.target_table,table,column,,
3,queries/source_query.sql,file,project.dataset.target_table,table,row,,
"""

# Example usage:
"""
# Command line usage:
python main.py mapping/validation_mapping.csv -o reports/validation_report.csv -l INFO

# Programmatic usage:
from main import DataValidationSystem

validation_system = DataValidationSystem(log_level='INFO')
report_path = validation_system.run_validation(
    mapping_file_path='mapping/validation_mapping.csv',
    output_file_path='reports/validation_report.csv'
)
"""

# =============================================================================
# File Structure Summary
# =============================================================================
"""
project_root/
├── main.py                          # Main orchestration script
├── requirements.txt                 # Python dependencies
├── config/                          # Database configuration files
│   ├── bigquery.yaml               # BigQuery SQL templates and connection
│   ├── snowflake.yaml              # Snowflake SQL templates and connection  
│   └── sqlserver.yaml              # SQL Server SQL templates and connection
├── mapping/                         # Mapping definition files
│   └── validation_mapping.csv      # Sample mapping file
├── utils/                           # Utility modules
│   ├── __init__.py                 # Package initializer
│   ├── logger.py                   # Centralized logging utility
│   ├── query_formatter.py          # Safe SQL formatting utility
│   └── database_connector.py       # Database connection abstractions
├── validators/                      # Validation logic modules
│   ├── __init__.py                 # Package initializer
│   ├── base_validator.py           # Abstract base validator
│   └── aggregate_validator.py      # Aggregate validation implementation
├── reports/                         # Generated validation reports
│   └── validation_report_*.csv     # Timestamped report files
└── logs/                           # Log files (optional)
    └── validation_*.log            # Timestamped log files
"""

print("Dynamic Data Validation System - Complete Implementation")
print("=====================================================")
print("This system provides:")
print("1. Modular, extensible architecture following SOLID principles")
print("2. Support for BigQuery, Snowflake, and SQL Server")
print("3. Dynamic query handling with temporary table creation")
print("4. Robust error handling and comprehensive logging")
print("5. Configurable SQL templates via YAML files")
print("6. Safe query formatting with placeholder substitution")
print("7. Aggregate normalization and comparison")
print("8. Detailed CSV validation reports")
print("9. Command-line and programmatic interfaces")
print("10. Comprehensive exception handling and cleanup")