python sql dataframe sql-parser sqlglot

Parsing Complex SQL Statements to Map Columns, Tables, Databases Being Queried


I have been trying to create a Python script that parses some fairly complex SQL and returns a DataFrame of all of the columns, tables, and databases being queried, along with a section indicator showing where each column is located in the query (for example, the name of the temp table being created). This includes not only the final columns being selected but also columns in subqueries, SELECT statements in WHERE conditions, join conditions, COALESCE expressions, etc.

I am having trouble mapping the tables and databases in the WHERE conditions: the column names are returned, but the table and database values come back null. I am also having trouble returning only the actual table names and excluding any aliases.

import sqlglot
from sqlglot import expressions as exp
import pandas as pd

def extract_sql_metadata(sql: str) -> pd.DataFrame:
    """Return one row per distinct column reference found in ``sql``.

    Every statement in ``sql`` is parsed with sqlglot and walked once. Each
    ``exp.Column`` encountered (in projections, join conditions, WHERE
    clauses, subqueries, ...) produces a row with these fields:

    * ``column``        - the column name.
    * ``table``         - the physical table name the qualifier resolves to,
      ``subquery_<alias>`` when the qualifier is an aliased subquery, the raw
      qualifier when it is unknown, or ``None`` when an unqualified column
      cannot be attributed to a single source.
    * ``database``      - the ``db`` part of the resolved table, else ``None``.
    * ``query_section`` - the innermost enclosing CTE alias, subquery alias,
      ``CREATE`` target name, or ``"final_select"`` for an ``INSERT``;
      ``None`` when the column is in none of those.

    Args:
        sql: One or more SQL statements.

    Returns:
        A DataFrame with columns ``column``, ``table``, ``database`` and
        ``query_section`` (present even when no column was found).

    Notes:
        Aliases are kept in a single registry rather than per scope, so an
        alias reused for different tables resolves to the last one registered.
        Qualifier lookups are case-sensitive. Unqualified columns are assigned
        to the only FROM/JOIN source of their enclosing SELECT; without schema
        information this is a best-effort guess for correlated subqueries.
    """
    output_columns = ["column", "table", "database", "query_section"]
    rows = []
    seen = set()
    table_registry = {}  # alias (or bare table name) -> (table_name, db)

    def register_sources(stmt):
        # Register every table and aliased subquery up front. A SELECT's
        # projections are visited before its FROM/JOIN clauses, so registering
        # lazily during the walk would leave aliases unresolved.
        for source in stmt.find_all(exp.Table, exp.Subquery):
            if isinstance(source, exp.Table):
                table_name = source.name
                db_expr = source.args.get("db")
                db = db_expr.name if isinstance(db_expr, exp.Identifier) else None
                key = source.alias_or_name or table_name
                if key:
                    table_registry[key] = (table_name, db)
            else:
                alias = source.alias_or_name
                if alias:
                    table_registry[alias] = (f"subquery_{alias}", None)

    def resolve(col_expr):
        qualifier = col_expr.table
        if qualifier:
            # Unknown qualifiers are reported as-is rather than dropped.
            return table_registry.get(qualifier, (qualifier, None))

        # Unqualified column: attribute it to the enclosing SELECT's source
        # only when that SELECT reads from exactly one source.
        select = col_expr.find_ancestor(exp.Select)
        if select is None:
            return (None, None)
        sources = []
        for arg in select.args.values():
            for item in arg if isinstance(arg, list) else [arg]:
                if isinstance(item, (exp.From, exp.Join)):
                    sources.append(item.this)
        if len(sources) != 1:
            return (None, None)
        source = sources[0]
        if not isinstance(source, (exp.Table, exp.Subquery)):
            return (None, None)
        key = source.alias_or_name
        if not key:
            return (None, None)
        return table_registry.get(key, (None, None))

    def record_column(col_expr, section):
        col = col_expr.name
        resolved_table, db = resolve(col_expr)
        key = (col, resolved_table, db, section)
        if key not in seen:
            seen.add(key)
            rows.append({
                "column": col,
                "table": resolved_table,
                "database": db,
                "query_section": section,
            })

    def walk(node, section):
        if isinstance(node, exp.Column):
            record_column(node, section)
            return  # a column's children are identifiers only

        # Nodes that open a new section. The section is passed down the
        # recursion, so it is restored automatically on the way back up and
        # each subtree is visited exactly once.
        if isinstance(node, exp.CTE):
            section = node.alias_or_name
        elif isinstance(node, exp.Subquery):
            # An unaliased subquery stays in its parent's section.
            section = node.alias_or_name or section
        elif isinstance(node, exp.Create):
            created = node.this.name
            if not created:
                # The target may be wrapped (e.g. a column list), in which
                # case the name sits on the nested table.
                target = node.this.find(exp.Table)
                if target is not None:
                    created = target.name
            section = created
        elif isinstance(node, exp.Insert):
            section = "final_select"

        for child in node.args.values():
            for item in child if isinstance(child, list) else [child]:
                if isinstance(item, exp.Expression):
                    walk(item, section)

    for stmt in sqlglot.parse(sql):
        # sqlglot may yield None for an empty statement (e.g. after a
        # trailing semicolon).
        if stmt is None:
            continue
        register_sources(stmt)
        walk(stmt, None)

    return pd.DataFrame(rows, columns=output_columns)

Sample of a complex SQL statement to be parsed

-- Sample input: materialise the result of the SELECT below as table TRD.
CREATE TABLE TRD AS (
SELECT 
   TR.REQUEST_ID
   ,P17.THIS_WORKING
   ,P17.REQUEST_FIELD_VAL AS "AUTHORIZATION"
   ,P20.REQUEST_FIELD_VAL AS "CONTRACT PD/AOR"
FROM ADW_VIEWS_FSICS.FSICS_IPSS_TRAINING_REQUESTS TR
-- The detail table is joined twice, once per REQUEST_FIELD_EXTERNAL_ID value,
-- under the aliases P17 and P20.
LEFT JOIN ADW_VIEWS_FSICS.FSICS_IPSS_TRNG_REQUESTS_DET P17 
   ON TR.REQUEST_ID = P17.REQUEST_ID 
   AND P17.REQUEST_FIELD_EXTERNAL_ID = 'IPSS_MD_PROPERTY_17'
LEFT JOIN ADW_VIEWS_FSICS.FSICS_IPSS_TRNG_REQUESTS_DET P20 
   ON TR.REQUEST_ID = P20.REQUEST_ID 
   AND P20.REQUEST_FIELD_EXTERNAL_ID = 'IPSS_MD_PROPERTY_20'
-- Keep only requests whose id appears in MY_TNG_REQUESTS with EVENT_TYPE = 'BASIC'.
-- The columns inside this subquery are unqualified (no table alias).
WHERE TR.REQUEST_ID IN (
   SELECT REQUEST_ID
   FROM ADW_VIEWS_FSICS.MY_TNG_REQUESTS
   WHERE EVENT_TYPE = 'BASIC'
   )
); 

Given the above function and example SQL, I would want to get the below results.

column table database query_section
REQUEST_ID FSICS_IPSS_TRAINING_REQUESTS ADW_VIEWS_FSICS TRD
THIS_WORKING FSICS_IPSS_TRNG_REQUESTS_DET ADW_VIEWS_FSICS TRD
REQUEST_FIELD_VAL FSICS_IPSS_TRNG_REQUESTS_DET ADW_VIEWS_FSICS TRD
REQUEST_ID FSICS_IPSS_TRNG_REQUESTS_DET ADW_VIEWS_FSICS TRD
REQUEST_FIELD_EXTERNAL_ID FSICS_IPSS_TRNG_REQUESTS_DET ADW_VIEWS_FSICS TRD
REQUEST_ID MY_TNG_REQUESTS ADW_VIEWS_FSICS TRD
EVENT_TYPE MY_TNG_REQUESTS ADW_VIEWS_FSICS TRD

Solution

  • I was way overthinking it. In case anyone else stumbles across this, here is how I did it:

    from sqlglot.optimizer.qualify_columns  import qualify_columns
    from sqlglot.optimizer.scope            import traverse_scope
    from sqlglot                            import parse, exp
    
    def parse_query(sql_query, dialect='tsql'):
        """Return the physical columns referenced by each statement in ``sql_query``.

        Each statement is parsed with the given sqlglot dialect and passed
        through ``qualify_columns`` (with no schema) so that columns carry a
        table qualifier where sqlglot can determine one. Every scope of the
        statement is then inspected, and each column is matched to the source
        it is qualified with.

        Args:
            sql_query: One or more SQL statements.
            dialect: sqlglot dialect name used to read the SQL.

        Returns:
            A DataFrame with columns ``section``, ``database``, ``table`` and
            ``columns``. ``section`` is the upper-cased text of the statement's
            ``this`` argument (falling back to the statement's key). Columns
            whose source is not a physical table get ``None`` for ``database``
            and ``table``. Duplicates are removed per statement. The frame is
            empty, but has the same columns, when there are no statements.
        """
        output_columns = ['section', 'database', 'table', 'columns']

        df_list = []
        for ast in parse(sql_query, read=dialect):
            # parse() may yield None for an empty statement (e.g. after a
            # trailing semicolon); there is nothing to analyse in that case.
            if ast is None:
                continue

            ast = qualify_columns(ast, schema=None)
            section = str(ast.this).upper() if ast.this else str(ast.key).upper()

            physical_columns = []
            for scope in traverse_scope(ast):
                for c in scope.columns:
                    source = scope.sources.get(c.table)
                    if isinstance(source, exp.Table):
                        physical_columns.append((section, source.db, source.name, c.name))
                    else:
                        # Derived sources (subqueries, CTEs) or unresolved columns.
                        physical_columns.append((section, None, None, c.name))

            df = pd.DataFrame(physical_columns, columns=output_columns)
            df_list.append(df.drop_duplicates())

        # pd.concat raises ValueError on an empty list, so handle "no
        # statements" explicitly.
        if not df_list:
            return pd.DataFrame(columns=output_columns)

        return pd.concat(df_list, ignore_index=True)