Commit e9f4108d authored by Roman Alifanov's avatar Roman Alifanov

Fix comparison codegen, env var safety, escaped braces, DCE class reachability

- Fix numeric/string comparison operator selection in expr and stmt codegen - Use ${VAR:-} for env var reads to prevent unbound variable errors - Add {{ and }} as literal brace escapes in string interpolation - Fix DCE to discover transitive class dependencies through method callees - Fix &&/|| condition codegen to prevent __CT_RET variable conflicts
parent 3cfa569f
...@@ -38,6 +38,7 @@ _CMP_OPS = {'==', '!=', '<', '>', '<=', '>='} ...@@ -38,6 +38,7 @@ _CMP_OPS = {'==', '!=', '<', '>', '<=', '>='}
_LOGIC_OPS = {'&&', '||'} _LOGIC_OPS = {'&&', '||'}
_ARITH_OPS = {'+', '-', '*', '/', '%', '**'} _ARITH_OPS = {'+', '-', '*', '/', '%', '**'}
_BASH_CMP = {'==': '==', '!=': '!=', '<': '-lt', '>': '-gt', '<=': '-le', '>=': '-ge'} _BASH_CMP = {'==': '==', '!=': '!=', '<': '-lt', '>': '-gt', '<=': '-le', '>=': '-ge'}
_ARITH_CMP = {'==': '==', '!=': '!=', '<': '<', '>': '>', '<=': '<=', '>=': '>='}
_STR_CMP = {'==': '==', '!=': '!=', '<': '<', '>': '>'} _STR_CMP = {'==': '==', '!=': '!=', '<': '<', '>': '>'}
# Stdlib method → prefix # Stdlib method → prefix
...@@ -295,10 +296,9 @@ def _identifier(node: IRIdentifier, ctx: 'EmitContext') -> str: ...@@ -295,10 +296,9 @@ def _identifier(node: IRIdentifier, ctx: 'EmitContext') -> str:
if node.symbol and node.symbol.kind == 'func': if node.symbol and node.symbol.kind == 'func':
return f'"{node.symbol.bash_name()}"' return f'"{node.symbol.bash_name()}"'
name = _var_name(node) name = _var_name(node)
# env.VAR read
if name.startswith('env.'): if name.startswith('env.'):
env_key = name[4:] env_key = name[4:]
return f'"${{{env_key}}}"' return f'"${{{env_key}:-}}"'
return f'"${{{name}}}"' return f'"${{{name}}}"'
...@@ -371,12 +371,18 @@ def _binary(node: IRBinaryOp, ctx: 'EmitContext') -> str: ...@@ -371,12 +371,18 @@ def _binary(node: IRBinaryOp, ctx: 'EmitContext') -> str:
# Comparison # Comparison
if op in _CMP_OPS: if op in _CMP_OPS:
if lt.kind in ('int', 'float') and rt.kind in ('int', 'float'): is_numeric = lt.kind in ('int', 'float') or rt.kind in ('int', 'float')
is_string = lt.kind == 'string' and rt.kind == 'string'
if op in ('<', '>', '<=', '>=') and not is_string:
la = _to_arith(lv) la = _to_arith(lv)
ra = _to_arith(rv) ra = _to_arith(rv)
bash_op = _BASH_CMP[op] arith_op = _ARITH_CMP[op]
return f'$([[ $(({la} {bash_op} {ra})) -ne 0 ]] && echo true || echo false)' return f'$([[ $(({la} {arith_op} {ra})) -ne 0 ]] && echo true || echo false)'
# string comparison if is_numeric:
la = _to_arith(lv)
ra = _to_arith(rv)
arith_op = _ARITH_CMP[op]
return f'$([[ $(({la} {arith_op} {ra})) -ne 0 ]] && echo true || echo false)'
str_op = _STR_CMP.get(op, op) str_op = _STR_CMP.get(op, op)
return f'$([[ {lv} {str_op} {rv} ]] && echo true || echo false)' return f'$([[ {lv} {str_op} {rv} ]] && echo true || echo false)'
......
...@@ -383,24 +383,38 @@ def _condition_bash(cond, ctx: 'EmitContext') -> str: ...@@ -383,24 +383,38 @@ def _condition_bash(cond, ctx: 'EmitContext') -> str:
if op in ('==', '!=', '<', '>', '<=', '>='): if op in ('==', '!=', '<', '>', '<=', '>='):
lv = expr_(cond.left, ctx) lv = expr_(cond.left, ctx)
rv = expr_(cond.right, ctx) rv = expr_(cond.right, ctx)
# Numeric comparison is_numeric = lt.kind in ('int', 'float') or rt.kind in ('int', 'float')
if lt.kind in ('int', 'float') or rt.kind in ('int', 'float'): is_string = lt.kind == 'string' and rt.kind == 'string'
if op in ('<', '>', '<=', '>='):
if is_string:
str_op = {'<': '<', '>': '>'}.get(op, op)
if op in ('<=', '>='):
la = _to_arith(lv)
ra = _to_arith(rv)
bash_op = _BASH_CMP_NUM[op]
return f'[[ "{la}" {bash_op} "{ra}" ]]'
return f'[[ {lv} {str_op} {rv} ]]'
la = _to_arith(lv) la = _to_arith(lv)
ra = _to_arith(rv) ra = _to_arith(rv)
bash_op = _BASH_CMP_NUM[op] bash_op = _BASH_CMP_NUM[op]
return f'[[ "{la}" {bash_op} "{ra}" ]]' return f'[[ "{la}" {bash_op} "{ra}" ]]'
# String comparison if is_numeric:
str_op = {'==': '==', '!=': '!=', '<': '<', '>': '>', '<=': '<=', '>=': '>='}.get(op, op) la = _to_arith(lv)
ra = _to_arith(rv)
bash_op = _BASH_CMP_NUM[op]
return f'[[ "{la}" {bash_op} "{ra}" ]]'
str_op = {'==': '==', '!=': '!='}.get(op, op)
return f'[[ {lv} {str_op} {rv} ]]' return f'[[ {lv} {str_op} {rv} ]]'
if op == '&&': if op in ('&&', '||'):
lc = _condition_bash(cond.left, ctx) saved_len = len(ctx._output)
rc = _condition_bash(cond.right, ctx)
return f'{lc} && {rc}'
if op == '||':
lc = _condition_bash(cond.left, ctx) lc = _condition_bash(cond.left, ctx)
if len(ctx._output) > saved_len and '__CT_RET' in lc:
tmp = ctx.fresh_tmp()
ctx.emit(f'{tmp}="${{__CT_RET}}"')
lc = lc.replace('${__CT_RET}', f'${{{tmp}}}')
rc = _condition_bash(cond.right, ctx) rc = _condition_bash(cond.right, ctx)
return f'{lc} || {rc}' return f'{lc} {op} {rc}'
# Logical negation # Logical negation
if isinstance(cond, IRUnaryOp) and cond.operator == '!': if isinstance(cond, IRUnaryOp) and cond.operator == '!':
......
...@@ -140,6 +140,12 @@ class Lexer: ...@@ -140,6 +140,12 @@ class Lexer:
self._advance() self._advance()
continue continue
if ch == '{': if ch == '{':
# {{ is a literal '{' (escaped brace)
if self._ch(1) == '{':
text_buf.append('{')
self._advance() # consume first '{'
self._advance() # consume second '{'
continue
# Begin interpolation # Begin interpolation
flush_text() flush_text()
interp_line, interp_col = self.line, self.column interp_line, interp_col = self.line, self.column
...@@ -167,6 +173,12 @@ class Lexer: ...@@ -167,6 +173,12 @@ class Lexer:
col=interp_col, col=interp_col,
)) ))
continue continue
# }} is a literal '}' (escaped brace)
if ch == '}' and self._ch(1) == '}':
text_buf.append('}')
self._advance() # consume first '}'
self._advance() # consume second '}'
continue
text_buf.append(self._advance()) text_buf.append(self._advance())
flush_text() flush_text()
......
...@@ -92,6 +92,31 @@ def run_dce(ir: IRProgram, call_graph: CallGraph, ...@@ -92,6 +92,31 @@ def run_dce(ir: IRProgram, call_graph: CallGraph,
for stmt in ir.top_stmts: for stmt in ir.top_stmts:
_collect_class_refs(stmt, reachable_class_names) _collect_class_refs(stmt, reachable_class_names)
# Expand reachability through methods of reachable classes:
# methods are kept by class name, but their transitive callees
# (e.g. new OtherClass()) need to be discovered via BFS.
cls_methods = {cl.name: cl.methods for cl in ir.classes}
changed = True
while changed:
changed = False
method_syms: set[Symbol] = set()
for cname in reachable_class_names:
for m in cls_methods.get(cname, []):
if m.symbol and m.symbol not in reachable:
method_syms.add(m.symbol)
if method_syms:
extra = call_graph.reachable_from(method_syms)
extra -= reachable
if extra:
reachable.update(extra)
reachable.update(method_syms)
for sym in extra:
if sym.kind == 'class' and sym.name not in reachable_class_names:
reachable_class_names.add(sym.name)
changed = True
else:
reachable.update(method_syms)
# Keep parent classes of reachable classes (needed for inherited method aliases) # Keep parent classes of reachable classes (needed for inherited method aliases)
cls_by_name = {cl.name: cl for cl in ir.classes} cls_by_name = {cl.name: cl for cl in ir.classes}
worklist = list(reachable_class_names) worklist = list(reachable_class_names)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment