I have been trying to create a Python script that parses fairly complex SQL and returns a dataframe of all the columns, tables, and databases being queried, along with a section indicator showing where each column appears in the query (for example, the name of a temp table being created). This includes not only the final columns being selected but also columns in subqueries, SELECT statements in WHERE conditions, join conditions, COALESCE expressions, etc.
I am having trouble mapping the tables and databases for columns in the WHERE conditions: the column names are returned, but their table and database values come back null. I am also having trouble returning only the actual table names rather than their aliases.
import sqlglot
from sqlglot import expressions as exp
import pandas as pd
def extract_sql_metadata(sql: str) -> pd.DataFrame:
    """Return one row per distinct column reference found in *sql*.

    Each row carries the column name, the table it was resolved to, that
    table's database qualifier, and the "query section" the reference sits
    in (the table being created, a CTE or subquery alias, or
    ``"final_select"`` under an INSERT).

    Resolution is done in two passes per statement so that aliases are known
    before any column is recorded:

    1. every ``exp.Table`` and aliased ``exp.Subquery`` in the statement is
       registered under its alias (or bare name);
    2. the tree is walked once and each ``exp.Column`` is resolved against
       that registry.

    Unqualified columns are attributed to the enclosing SELECT's table only
    when that SELECT reads from exactly one table and no derived table;
    otherwise their table and database are left as ``None``.

    NOTE(review): the alias registry is shared by the whole input, so the
    same alias used for different tables in different scopes resolves to
    whichever was registered last.

    Args:
        sql: One or more SQL statements.

    Returns:
        A DataFrame with the columns ``column``, ``table``, ``database`` and
        ``query_section`` (present even when no column was found).
    """
    output_columns = ["column", "table", "database", "query_section"]
    rows = []
    seen = set()
    table_registry = {}  # alias (or bare table name) -> (table_name, db)
    scope_default = {}  # id(select) -> (table_name, db) used for unqualified columns

    def describe_table(table: exp.Table):
        """Return ``(table_name, db)`` for a table node; db is None when absent."""
        db_expr = table.args.get("db")
        db = db_expr.name if isinstance(db_expr, exp.Identifier) else None
        return table.name, db

    def register_sources(stmt: exp.Expression) -> None:
        """Register every table and aliased subquery of *stmt* before columns are read."""
        for table in stmt.find_all(exp.Table):
            alias = table.alias_or_name or table.name
            # A nameless source must not claim the "" key, which is what an
            # unqualified column reports as its table.
            if alias:
                table_registry[alias] = describe_table(table)
        for subquery in stmt.find_all(exp.Subquery):
            alias = subquery.alias_or_name
            if alias:
                table_registry[alias] = (f"subquery_{alias}", None)

    def resolve_unqualified(col_expr: exp.Column):
        """Attribute an unqualified column to its SELECT's only table, if unambiguous."""
        select = col_expr.find_ancestor(exp.Select)
        if select is None:
            return None, None
        cache_key = id(select)
        if cache_key not in scope_default:
            # Only sources whose nearest enclosing SELECT is this one count;
            # tables inside nested subqueries belong to those subqueries.
            local_tables = [
                table
                for table in select.find_all(exp.Table)
                if table.find_ancestor(exp.Select) is select
            ]
            has_derived_source = any(
                isinstance(sub.parent, (exp.From, exp.Join))
                and sub.find_ancestor(exp.Select) is select
                for sub in select.find_all(exp.Subquery)
            )
            if len(local_tables) == 1 and not has_derived_source:
                scope_default[cache_key] = describe_table(local_tables[0])
            else:
                scope_default[cache_key] = (None, None)
        return scope_default[cache_key]

    def record_column(col_expr: exp.Column, section) -> None:
        """Resolve *col_expr* and append it to the result unless already seen."""
        qualifier = col_expr.table
        if qualifier:
            # Unknown qualifiers are kept verbatim so nothing is silently lost.
            resolved_table, db = table_registry.get(
                qualifier, (qualifier, col_expr.db or None)
            )
        else:
            resolved_table, db = resolve_unqualified(col_expr)
        key = (col_expr.name, resolved_table, db, section)
        if key in seen:
            return
        seen.add(key)
        rows.append({
            "column": col_expr.name,
            "table": resolved_table,
            "database": db,
            "query_section": section,
        })

    def traverse(node, section) -> None:
        """Walk *node* once, passing the current section down explicitly."""
        if not isinstance(node, exp.Expression):
            return
        if isinstance(node, exp.Column):
            record_column(node, section)
            return  # a column's children are only identifiers
        if isinstance(node, (exp.CTE, exp.Subquery)):
            # Unaliased subqueries stay in the section that contains them.
            section = node.alias_or_name or section
        elif isinstance(node, exp.Create):
            target = node.this
            created = target.name if isinstance(target, exp.Expression) else ""
            section = created or section
        elif isinstance(node, exp.Insert):
            section = "final_select"
        for child in node.args.values():
            if isinstance(child, exp.Expression):
                traverse(child, section)
            elif isinstance(child, list):
                for item in child:
                    traverse(item, section)

    for stmt in sqlglot.parse(sql):
        # Skip entries that are not expressions (the original traversal also ignored them).
        if not isinstance(stmt, exp.Expression):
            continue
        register_sources(stmt)
        traverse(stmt, None)
    return pd.DataFrame(rows, columns=output_columns)
Sample of a complex SQL statement to be parsed:
CREATE TABLE TRD AS (
SELECT
TR.REQUEST_ID
,P17.THIS_WORKING
,P17.REQUEST_FIELD_VAL AS "AUTHORIZATION"
,P20.REQUEST_FIELD_VAL AS "CONTRACT PD/AOR"
FROM ADW_VIEWS_FSICS.FSICS_IPSS_TRAINING_REQUESTS TR
LEFT JOIN ADW_VIEWS_FSICS.FSICS_IPSS_TRNG_REQUESTS_DET P17
ON TR.REQUEST_ID = P17.REQUEST_ID
AND P17.REQUEST_FIELD_EXTERNAL_ID = 'IPSS_MD_PROPERTY_17'
LEFT JOIN ADW_VIEWS_FSICS.FSICS_IPSS_TRNG_REQUESTS_DET P20
ON TR.REQUEST_ID = P20.REQUEST_ID
AND P20.REQUEST_FIELD_EXTERNAL_ID = 'IPSS_MD_PROPERTY_20'
WHERE TR.REQUEST_ID IN (
SELECT REQUEST_ID
FROM ADW_VIEWS_FSICS.MY_TNG_REQUESTS
WHERE EVENT_TYPE = 'BASIC'
)
);
Given the above function and example SQL, I would expect the following results:
column | table | database | query_section |
---|---|---|---|
REQUEST_ID | FSICS_IPSS_TRAINING_REQUESTS | ADW_VIEWS_FSICS | TRD |
THIS_WORKING | FSICS_IPSS_TRNG_REQUESTS_DET | ADW_VIEWS_FSICS | TRD |
REQUEST_FIELD_VAL | FSICS_IPSS_TRNG_REQUESTS_DET | ADW_VIEWS_FSICS | TRD |
REQUEST_ID | FSICS_IPSS_TRNG_REQUESTS_DET | ADW_VIEWS_FSICS | TRD |
REQUEST_FIELD_EXTERNAL_ID | FSICS_IPSS_TRNG_REQUESTS_DET | ADW_VIEWS_FSICS | TRD |
REQUEST_ID | MY_TNG_REQUESTS | ADW_VIEWS_FSICS | TRD |
EVENT_TYPE | MY_TNG_REQUESTS | ADW_VIEWS_FSICS | TRD |
I was way overthinking it. Just in case anyone else stumbles across this, here is how I did it:
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.scope import traverse_scope
from sqlglot import parse, exp
def parse_query(sql_query, dialect='tsql'):
    """List the columns referenced by each statement in *sql_query*.

    Every statement is parsed with the given dialect and passed through
    ``qualify_columns``; each scope returned by ``traverse_scope`` is then
    inspected. A column whose source in that scope is an ``exp.Table`` is
    reported with that table's name and database; any other column is
    reported with ``None`` for both.

    Args:
        sql_query: One or more SQL statements.
        dialect: sqlglot dialect name used to read the SQL.

    Returns:
        A DataFrame with the columns ``section``, ``database``, ``table``
        and ``columns``, de-duplicated per statement. ``section`` is the
        upper-cased text of the statement's ``this`` argument when it has
        one, otherwise the upper-cased expression key. The frame is empty
        (but keeps its headers) when nothing could be parsed.
    """
    output_columns = ['section', 'database', 'table', 'columns']
    df_list = []
    for ast in parse(sql_query, read=dialect):
        if ast is None:
            # NOTE(review): presumably produced for empty statements — skipped
            # so qualify_columns is never handed None.
            continue
        ast = qualify_columns(ast, schema=None)
        section = str(ast.this).upper() if ast.this else str(ast.key).upper()
        physical_columns = []
        for scope in traverse_scope(ast):
            for c in scope.columns:
                source = scope.sources.get(c.table)
                if isinstance(source, exp.Table):
                    # A table without a database qualifier is reported as None
                    # rather than an empty string.
                    physical_columns.append(
                        (section, source.db or None, source.name, c.name)
                    )
                else:
                    physical_columns.append((section, None, None, c.name))
        df = pd.DataFrame(physical_columns, columns=output_columns)
        df_list.append(df.drop_duplicates())
    if not df_list:
        # pd.concat raises ValueError on an empty list of frames.
        return pd.DataFrame(columns=output_columns)
    return pd.concat(df_list, ignore_index=True)