javaantlrantlr4

Multiple problems with ANTLR visitor processing


I am doing a pseudo programming language on ANTLR and Java. I met a few problems I can't deal with:

  1. If, else if and else: only one statement is executed per each condition, so others will be skipped. For example: else { show(a); show(b); }. Only a will be shown. return of int was working previously, but now null.

  2. I couldn't manage to make an arrayAccess, meaning by that I can't make this: int b = array[1] or show(array[2]).

  3. Consider the following function and call:

    public string fun(string a) {
        return "Haiii, " + a + "!\n";
    }
    
    string greet = fun("John")
    

    greet remains empty, there is nothing in it.

Here is my GrammarsVisitor.java:

import java.util.HashMap;
import java.util.Map;
import java.util.Stack;
import java.util.List;
import java.util.ArrayList;

public class GrammarsVisitor extends GrammarBaseVisitor<Object> {
    private final Stack<Map<String, Object>> scopes = new Stack<>();
    private final Map<String, FunctionDefinition> functions = new HashMap<>();

    public GrammarsVisitor() {
        // Push the global scope
        scopes.push(new HashMap<>());
    }

    @Override
    public Object visitVariable(GrammarParser.VariableContext ctx) {
        String type = ctx.TYPE().getText();
        String variableName = ctx.ID().getText();
        Object value = null;

        // If the variable is initialized
        if (ctx.expression() != null) {
            value = visitExpression(ctx.expression());
        }

        // Storing the variable in the current scope based on its type
        switch (type) {
            case "int":
                if (value == null) {
                    value = 0;
                } else {
                    value = Integer.parseInt(value.toString());
                }
                break;
            case "float":
                if (value == null) {
                    value = 0.0f;
                } else {
                    value = Float.parseFloat(value.toString());
                }
                break;
            case "string":
                if (value == null) {
                    value = "";
                } else {
                    value = value.toString();
                }
                break;
            case "bool":
                if (value == null) {
                    value = false;
                } else {
                    value = Boolean.parseBoolean(value.toString());
                }
                break;
            default:
                throw new RuntimeException("Unknown type: " + type);
        }

        // Adding the variable to the current scope
        if (scopes.peek().containsKey(variableName)) {
            throw new RuntimeException("Variable already declared in the current scope: " + variableName);
        }
        scopes.peek().put(variableName, value);
        System.out.println("Variable declared: " + type + " " + variableName + " = " + value);
        return null;
    }

    @Override
    public Object visitArray(GrammarParser.ArrayContext ctx) {
        String type = ctx.TYPE() != null ? ctx.TYPE().getText() : ctx.ID(0).getText();
        String arrayName = ctx.ID(0).getText();
        java.util.List<Object> array = new java.util.ArrayList<>();

        // If the array is initialized
        if (ctx.arrayInit() != null) {
            for (GrammarParser.ExpressionContext exprCtx : ctx.arrayInit().expression()) {
                array.add(visitExpression(exprCtx));
            }
        }

        // Adding the array to the current scope
        if (scopes.peek().containsKey(arrayName)) {
            throw new RuntimeException("Array already declared in the current scope: " + arrayName);
        }
        scopes.peek().put(arrayName, array);
        System.out.println("Array declared: " + type + "[] " + arrayName + " = " + array);
        return null;
    }

    public Object visitExpression(GrammarParser.ExpressionContext ctx) {
        if (ctx.functionCall() != null) {
            return visitFunctionCall(ctx.functionCall());
        }
        if (ctx.BOOL() != null) {
            return Boolean.parseBoolean(ctx.BOOL().getText());
        } else if (ctx.INT() != null) {
            return Integer.parseInt(ctx.INT().getText());
        } else if (ctx.FLOAT() != null) {
            return Float.parseFloat(ctx.FLOAT().getText());
        } else if (ctx.STRING() != null) {
            return ctx.STRING().getText()
            .replace("\"", "")
            .replace("\\n", "\n")
            .replace("\\t", "\t");
        } else if (ctx.ID() != null) {
            String variableName = ctx.ID().getText();
            for (int i = scopes.size() - 1; i >= 0; i--) {
                if (scopes.get(i).containsKey(variableName)) {
                    return scopes.get(i).get(variableName);
                }
            }
            throw new RuntimeException("Undefined variable: " + variableName);
        } else if (ctx.getChildCount() == 3 && ctx.getChild(0).getText().equals("(") && ctx.getChild(2).getText().equals(")")) {
            return visitExpression(ctx.expression(0));
        } else if (ctx.getChildCount() == 2 && ctx.getChild(0).getText().equals("!")) {
            return !(Boolean) visitExpression(ctx.expression(0));
        } else if (ctx.getChildCount() == 3) {
            Object left = visitExpression(ctx.expression(0));
            Object right = visitExpression(ctx.expression(1));
            String operator = ctx.getChild(1).getText();

            switch (operator) {
                case "%":
                    if (left instanceof Integer && right instanceof Integer) {
                        return (Integer) left % (Integer) right;
                    } else if (left instanceof Float || right instanceof Float) {
                        return Float.parseFloat(left.toString()) % Float.parseFloat(right.toString());
                    }
                    break;
                case "+":
                    if (left instanceof Integer && right instanceof Integer) {
                        return (Integer) left + (Integer) right;
                    } else if (left instanceof Float || right instanceof Float) {
                        return Float.parseFloat(left.toString()) + Float.parseFloat(right.toString());
                    } else if (left instanceof String || right instanceof String) {
                        return left.toString() + right.toString();
                    }
                    break;
                case "-":
                    if (left instanceof Integer && right instanceof Integer) {
                        return (Integer) left - (Integer) right;
                    } else if (left instanceof Float || right instanceof Float) {
                        return Float.parseFloat(left.toString()) - Float.parseFloat(right.toString());
                    }
                    break;
                case "*":
                    if (left instanceof Integer && right instanceof Integer) {
                        return (Integer) left * (Integer) right;
                    } else if (left instanceof Float || right instanceof Float) {
                        return Float.parseFloat(left.toString()) * Float.parseFloat(right.toString());
                    }
                    break;
                case "/":
                    if (left instanceof Integer && right instanceof Integer) {
                        return (Integer) left / (Integer) right;
                    } else if (left instanceof Float || right instanceof Float) {
                        return Float.parseFloat(left.toString()) / Float.parseFloat(right.toString());
                    }
                    break;
                case "&&":
                    return (Boolean) left && (Boolean) right;
                case "||":
                    return (Boolean) left || (Boolean) right;
                case "==":
                    return left.equals(right);
                case "!=" :
                    return !left.equals(right);
                case "<":
                    if (left instanceof Integer && right instanceof Integer) {
                        return (Integer) left < (Integer) right;
                    } else if (left instanceof Float || right instanceof Float) {
                        return Float.parseFloat(left.toString()) < Float.parseFloat(right.toString());
                    }
                    break;
                case ">":
                    if (left instanceof Integer && right instanceof Integer) {
                        return (Integer) left > (Integer) right;
                    } else if (left instanceof Float || right instanceof Float) {
                        return Float.parseFloat(left.toString()) > Float.parseFloat(right.toString());
                    }
                    break;
                case "<=" :
                    if (left instanceof Integer && right instanceof Integer) {
                        return (Integer) left <= (Integer) right;
                    } else if (left instanceof Float || right instanceof Float) {
                        return Float.parseFloat(left.toString()) <= Float.parseFloat(right.toString());
                    }
                    break;
                case ">=" :
                    if (left instanceof Integer && right instanceof Integer) {
                        return (Integer) left >= (Integer) right;
                    } else if (left instanceof Float || right instanceof Float) {
                        return Float.parseFloat(left.toString()) >= Float.parseFloat(right.toString());
                    }
                    break;
                default:
                    throw new RuntimeException("Unsupported operator: " + operator);
            }
        }
        throw new RuntimeException("Invalid expression: " + ctx.getText());
    }

    @Override
    public Object visitFunction(GrammarParser.FunctionContext ctx) {
        String functionName = ctx.ID(0).getText();
        List<String> paramNames = new ArrayList<>();
        if (ctx.TYPE() != null && ctx.ID().size() > 1) {
            for (int i = 1; i < ctx.ID().size(); i++) {
                paramNames.add(ctx.ID(i).getText());
            }
        }
        functions.put(functionName, new FunctionDefinition(ctx, paramNames));
        System.out.println("Function declared: " + functionName);
        return null;
    }

    @Override
    public Object visitFunctionCall(GrammarParser.FunctionCallContext ctx) {
        String functionName = ctx.ID(0).getText();
        FunctionDefinition functionDef = functions.get(functionName);
        if (functionDef == null) {
            throw new RuntimeException("Undefined function: " + functionName);
        }

        // Creating a new scope for function call
        scopes.push(new HashMap<>());

        // Assign parameter values to new scope
        List<GrammarParser.ExpressionContext> args = ctx.expression();
        for (int i = 0; i < args.size(); i++) {
            String paramName = functionDef.paramNames.get(i);
            Object paramValue = visitExpression(args.get(i));
            scopes.peek().put(paramName, paramValue);
        }

        // Execute the function body
        Object returnValue = null;
        for (GrammarParser.StatementContext stmtCtx : functionDef.functionCtx.statement()) {
            returnValue = visit(stmtCtx);
            if (returnValue != null) {
                break;
            }
        }

        // Pop the function scope
        scopes.pop();
        return returnValue;
    }

    @Override
    public Object visitReturnStatement(GrammarParser.ReturnStatementContext ctx) {
        return visitExpression(ctx.expression());
    }

    @Override
    public Object visitShow(GrammarParser.ShowContext ctx) {
        Object value = visitExpression(ctx.expression());
        System.out.println(value);
        return null;
    }

    @Override
    public Object visitLoop(GrammarParser.LoopContext ctx) {
        String loopVariable = ctx.ID().getText();
        Object start = visitExpression(ctx.expression(0));
        Object end = visitExpression(ctx.expression(1));

        if (!(start instanceof Integer) || !(end instanceof Integer)) {
            throw new RuntimeException("The 'for' loop must have integer start and end values: " + ctx.getText());
        }

        int startValue = (Integer) start;
        int endValue = (Integer) end;

        if (startValue <= endValue) {
            for (int i = startValue; i <= endValue; i++) {
                scopes.peek().put(loopVariable, i);

                for (GrammarParser.StatementContext stmtCtx : ctx.statement()) {
                    visit(stmtCtx);
                }
            }
        } else {
            for (int i = startValue; i >= endValue; i--) {
                scopes.peek().put(loopVariable, i);

                for (GrammarParser.StatementContext stmtCtx : ctx.statement()) {
                    visit(stmtCtx);
                }
            }
        }

        return null;
    }

    @Override
    public Object visitIfStatement(GrammarParser.IfStatementContext ctx) {
        for (int i = 0; i < ctx.expression().size(); i++) {
            Object condition = visitExpression(ctx.expression(i));
            if (!(condition instanceof Boolean)) {
                throw new RuntimeException("The condition must be a boolean expression: " + ctx.expression(i).getText());
            }

            if ((Boolean) condition) {
                visit(ctx.statement(i));
                return null; // Exit after the first true condition block
            }
        }

        // If none of the conditions are true and there's an 'else' block
        if (ctx.statement().size() > ctx.expression().size()) {
            visit(ctx.statement(ctx.expression().size()));
        }

        return null;
    }

    private static class FunctionDefinition {
        GrammarParser.FunctionContext functionCtx;
        List<String> paramNames;

        FunctionDefinition(GrammarParser.FunctionContext functionCtx, List<String> paramNames) {
            this.functionCtx = functionCtx;
            this.paramNames = paramNames;
        }
    }
}

And this is my Grammar.g4 (thanks to Bart Kiers):

grammar Grammar;

start : statement* EOF;

statement
    : variable ';'
    | objectInstance ';'
    | array ';'
    | ifStatement
    | function
    | loop
    | functionCall ';'
    | show ';'
    | arrayAccess ';'
    | class
    | returnStatement ';'
    ;

function
    : ACCESS TYPE ID '(' (TYPE ID (',' TYPE ID)*)? ')' '{' statement* '}'
    | ACCESS 'void' ID '(' (TYPE ID (',' TYPE ID)*)? ')' '{' statement* '}'
    ;

variable : TYPE ID ('=' expression)?;
objectInstance : objectType=ID instanceName=ID ('=' expression)?;
array          : (TYPE | ID) ID '[' ']' ('=' arrayInit)?;
ifStatement    : 'if' '(' expression ')' '{' statement* '}' ('else if' '(' expression ')' '{' statement* '}')* ('else' '{' statement* '}')?;
loop           : 'for' ID 'in' expression 'to' expression '{' statement* '}';
functionCall   : (ID '.')? ID '(' (expression (',' expression)*?)? ')';
show           : 'show' '(' expression ')';
class          : 'class' ID '{' classEntries '}';
classEntries   : ((variable ';') | (array ';') | function)*;
arrayInit      : '[' expression (',' expression)* ']';
returnStatement: 'return' expression;
arrayAccess    : ID '[' expression ']' ;

expression
    : '(' expression ')'
    | '!' expression
    | expression ('&&' | '||') expression
    | expression ('/' | '*' | '%') expression
    | expression ('+' | '-') expression
    | expression COMPARISON expression
    | 'new' ID '(' ((expression | arrayInit) (',' (expression | arrayInit))*)? ')' 
    | functionCall
    | STRING
    | ID
    | INT
    | BOOL
    | FLOAT
    | ID
    ;

ACCESS     : 'private' | 'public';
COMPARISON : '>' | '<' | '>=' | '<=' | '==';
TYPE       : 'int' | 'float' | 'string' | 'bool';
BOOL       : 'true' | 'false' ;
ID         : [a-zA-Z_][a-zA-Z0-9_]* ;
STRING     : '"' (~[\\"] | '\\' .)* '"';
INT        : [0-9]+;
FLOAT      : [0-9]+ '.' [0-9]+;
WS         : [ \t\r\n]+ -> channel(HIDDEN);
LINE_COMMENT : '//' ~[\r\n]* -> skip;
BLOCK_COMMENT : '/*' .*? '*/' -> skip;

Would be very grateful for your help in knowing what are the issues.


Solution

  • You should also override the Object visitStatement(GrammarParser.StatementContext ctx) method in your visitor. If you don't, ANTLR's base visitor will handle it, and will always return the last child, which is null. In general, you should almost always override all parser methods.

    Try adding this:

    @Override
    public Object visitStatement(GrammarParser.StatementContext ctx) {
        if (ctx.variable() != null) return this.visitVariable(ctx.variable());
        if (ctx.objectInstance() != null) return this.visitObjectInstance(ctx.objectInstance());
        if (ctx.array() != null) return this.visitArray(ctx.array());
        if (ctx.ifStatement() != null) return this.visitIfStatement(ctx.ifStatement());
        if (ctx.function() != null) return this.visitFunction(ctx.function());
        if (ctx.loop() != null) return this.visitLoop(ctx.loop());
        if (ctx.functionCall() != null) return this.visitFunction(ctx.function());
        if (ctx.show() != null) return this.visitShow(ctx.show());
        if (ctx.arrayAccess() != null) return this.visitArrayAccess(ctx.arrayAccess());
        if (ctx.class_() != null) return this.visitClass(ctx.class_());
        if (ctx.returnStatement() != null) return this.visitReturnStatement(ctx.returnStatement());
    
        throw new RuntimeException("Unexpected statement: " + ctx.getText());
    }
    

    EDIT

    As Kaby showed in the comments, you don't really need to see if certain contexts are null, just do:

     @Override
    public Object visitStatement(GrammarParser.StatementContext ctx) {
        Object value = visit(ctx.children.get(0));
        return value;
    }
    

    EDIT II

    An array index might work like this:

    expression
        : ...
        | arrayIndex
        | ...
        ;
    
    arrayIndex
        : ID '[' expression ']'
        ;
    
    

    and in your visitor do:

    @Override
    public Object visitExpression(GrammarParser.ExpressionContext ctx) {
        ...
        if (ctx.arrayIndex() != null) {
            return visit(ctx.arrayIndex());
        }
        ...
    }
    
    @Override
    public Object visitArrayIndex(GrammarParser.ArrayIndexContext ctx) {
        ArrayList<Object> array = (ArrayList<Object>)scopes.peek().get(ctx.ID().getText());
        Integer index = (Integer)visit(ctx.expression());
        Object value = array.get(index);
    
        return value;
    }