Commit 0e7f2d60 authored by Roman Alifanov's avatar Roman Alifanov

Improve parameter passing for arrays, dicts and class instances

- Add proper type tracking for function and method parameters
- DCE now tracks array-returning methods (keys, split, slice)
- Fix split/slice to return arrays via __CT_RET_ARR
- Add foreach support for inline split expressions
- Detect object parameters by unknown method calls
- Add comprehensive tests for parameter passing

New test classes:
- TestFunctionParameterPassing (11 tests)
- TestClassInstancePassing (4 tests)
parent a9fe7b9d
...@@ -1010,12 +1010,35 @@ Error: Unknown method 'badMethod' for type 'fs'. Available: append, exists, list ...@@ -1010,12 +1010,35 @@ Error: Unknown method 'badMethod' for type 'fs'. Available: append, exists, list
``` ```
CodeGenerator CodeGenerator
├── StdlibMixin # stdlib.py — встроенные функции
├── AwkCodegenMixin # awk_codegen.py — @awk компиляция
├── ExprMixin # expr_codegen.py — выражения ├── ExprMixin # expr_codegen.py — выражения
├── StmtMixin # stmt_codegen.py — statements ├── StmtMixin # stmt_codegen.py — statements
├── ClassMixin # class_codegen.py — классы/методы ├── ClassMixin # class_codegen.py — классы/методы
├── DecoratorMixin # decorator_codegen.py — декораторы ├── DecoratorMixin # decorator_codegen.py — декораторы
├── DispatchMixin # dispatch_codegen.py — диспатч, присваивания ├── DispatchMixin # dispatch_codegen.py — диспатч, присваивания
└── CseMixin # cse_codegen.py — CSE оптимизации ├── CseMixin # cse_codegen.py — CSE оптимизации
├── StdlibMixin # stdlib.py — встроенные функции
└── AwkCodegenMixin # awk_codegen.py — @awk компиляция
Вспомогательные модули:
├── constants.py # Константы (RET_VAR, TMP_PREFIX, CLASS_FUNC_PREFIX, etc.)
└── methods.py # Единый реестр методов для bash/awk синхронизации
``` ```
### Добавление новых методов
Для добавления нового метода достаточно обновить `methods.py`:
```python
STRING_METHODS = {
...
"new_method": MethodDef(
"new_method",
min_args=1,
max_args=1,
bash_func="__ct_str_new_method",
awk_gen=lambda obj, args: f"awk_impl({obj}, {args[0]})"
),
}
```
Bash и AWK codegen автоматически подхватят изменения.
...@@ -358,33 +358,35 @@ Uses `json.get()` for parsing Telegram API responses and `str.urlencode()` for U ...@@ -358,33 +358,35 @@ Uses `json.get()` for parsing Telegram API responses and `str.urlencode()` for U
## Project Structure ## Project Structure
``` ```
bootstrap/ # Bootstrap compiler (Python) bootstrap/ # Bootstrap compiler (Python)
├── main.py # CLI entry point ├── main.py # CLI entry point
├── lexer.py # Tokenizer ├── lexer.py # Tokenizer
├── tokens.py # Token type definitions ├── tokens.py # Token type definitions
├── parser.py # Recursive descent parser, AST generation ├── parser.py # Recursive descent parser, AST generation
├── ast_nodes.py # AST node classes ├── ast_nodes.py # AST node classes
├── codegen.py # Main Bash code generator ├── errors.py # Error handling
├── stmt_codegen.py # Statement generation (mixin) ├── constants.py # Codegen constants (RET_VAR, TMP_PREFIX, etc.)
├── expr_codegen.py # Expression generation (mixin) ├── methods.py # Unified method registry for bash/awk sync
├── class_codegen.py # Class/method generation (mixin) ├── dce.py # Dead code elimination
├── codegen.py # Main Bash code generator (mixin coordinator)
├── expr_codegen.py # Expression generation (mixin)
├── stmt_codegen.py # Statement generation (mixin)
├── class_codegen.py # Class/method generation (mixin)
├── dispatch_codegen.py # Method dispatch, assignments (mixin) ├── dispatch_codegen.py # Method dispatch, assignments (mixin)
├── decorator_codegen.py # Decorator wrappers (mixin) ├── decorator_codegen.py # Decorator wrappers (mixin)
├── awk_codegen.py # AWK generator for @awk (mixin) ├── cse_codegen.py # Common subexpression elimination (mixin)
├── stdlib.py # Standard library generation (mixin) ├── stdlib.py # Standard library generation (mixin)
├── cse_codegen.py # Common subexpression elimination (mixin) └── awk_codegen.py # AWK generator for @awk (mixin)
├── dce.py # Dead code elimination
└── errors.py # Error handling lib/ # ContenT libraries
└── cli.ct # CLI library (urfave/cli style)
lib/ # ContenT libraries
└── cli.ct # CLI library (urfave/cli style) tests/ # Test suite
├── test_lexer.py # Lexer tests
tests/ # Test suite ├── test_parser.py # Parser tests
├── test_lexer.py # Lexer tests
├── test_parser.py # Parser tests
└── test_integration.py # Integration tests └── test_integration.py # Integration tests
examples/ # Example .ct programs examples/ # Example .ct programs
``` ```
## License ## License
......
...@@ -358,25 +358,27 @@ python3 content run examples/telegram_echobot/echobot.ct ...@@ -358,25 +358,27 @@ python3 content run examples/telegram_echobot/echobot.ct
## Структура проекта ## Структура проекта
``` ```
bootstrap/ # Bootstrap-компилятор (Python) bootstrap/ # Bootstrap-компилятор (Python)
├── main.py # CLI точка входа ├── main.py # CLI точка входа
├── lexer.py # Токенизатор ├── lexer.py # Токенизатор
├── tokens.py # Определения типов токенов ├── tokens.py # Определения типов токенов
├── parser.py # Рекурсивный спуск, генерация AST ├── parser.py # Рекурсивный спуск, генерация AST
├── ast_nodes.py # Классы узлов AST ├── ast_nodes.py # Классы узлов AST
├── codegen.py # Основной генератор Bash-кода ├── errors.py # Обработка ошибок
├── stmt_codegen.py # Генерация statements (миксин) ├── constants.py # Константы кодогенерации (RET_VAR, TMP_PREFIX, etc.)
├── expr_codegen.py # Генерация выражений (миксин) ├── methods.py # Единый реестр методов для bash/awk синхронизации
├── class_codegen.py # Генерация классов/методов (миксин) ├── dce.py # Устранение мёртвого кода
├── codegen.py # Основной генератор Bash-кода (координатор миксинов)
├── expr_codegen.py # Генерация выражений (миксин)
├── stmt_codegen.py # Генерация statements (миксин)
├── class_codegen.py # Генерация классов/методов (миксин)
├── dispatch_codegen.py # Диспатч методов, присваивания (миксин) ├── dispatch_codegen.py # Диспатч методов, присваивания (миксин)
├── decorator_codegen.py # Обёртки декораторов (миксин) ├── decorator_codegen.py # Обёртки декораторов (миксин)
├── awk_codegen.py # AWK-генератор для @awk (миксин) ├── cse_codegen.py # Устранение общих подвыражений (миксин)
├── stdlib.py # Генерация стандартной библиотеки (миксин) ├── stdlib.py # Генерация стандартной библиотеки (миксин)
├── cse_codegen.py # Устранение общих подвыражений (миксин) └── awk_codegen.py # AWK-генератор для @awk (миксин)
├── dce.py # Устранение мёртвого кода
└── errors.py # Обработка ошибок
lib/ # Библиотеки на ContenT lib/ # Библиотеки на ContenT
└── cli.ct # CLI-библиотека (стиль urfave/cli) └── cli.ct # CLI-библиотека (стиль urfave/cli)
tests/ # Тестовый набор tests/ # Тестовый набор
......
from .ast_nodes import * from .ast_nodes import (
ClassDecl, FunctionDecl, ArrayLiteral, DictLiteral, NilLiteral, NewExpr,
CallExpr, Identifier, Assignment, MemberAccess, ThisExpr, ReturnStmt,
ConstructorDecl, Parameter, Block, ForeachStmt, IfStmt, WhileStmt, ForStmt,
ExpressionStmt, BinaryOp, IndexAccess
)
from .methods import ARRAY_METHODS, DICT_METHODS
# Method names that identify the receiver as an array.
ARRAY_ONLY_METHODS = {
    "push", "pop", "shift",
    "join", "slice",
    "map", "filter",
}

# Method names that identify the receiver as a dict.
DICT_ONLY_METHODS = {"has", "del", "keys"}

# Full method sets: the type-specific names plus the accessors that arrays
# and dicts have in common ("get", "set", "len").
ARRAY_METHODS_ALL = {"get", "set", "len"} | ARRAY_ONLY_METHODS
DICT_METHODS_ALL = {"get", "set", "len"} | DICT_ONLY_METHODS

# Builtin string method names ("len" is shared with arrays and dicts).
STRING_METHODS_ALL = {
    "upper", "lower", "trim", "len",
    "contains", "starts", "ends", "index",
    "replace", "substr", "split", "charAt", "urlencode",
}

# Every builtin method name; a call to anything outside this set is taken
# as a method of a user-defined class.
ALL_KNOWN_METHODS = STRING_METHODS_ALL | ARRAY_METHODS_ALL | DICT_METHODS_ALL
class ClassMixin: class ClassMixin:
...@@ -176,9 +191,11 @@ class ClassMixin: ...@@ -176,9 +191,11 @@ class ClassMixin:
self.in_class_method = True self.in_class_method = True
old_in_function = self.in_function old_in_function = self.in_function
old_local_vars = self.local_vars.copy() old_local_vars = self.local_vars.copy()
old_object_vars = self.object_vars.copy()
self.in_function = True self.in_function = True
self.local_vars = set() self.local_vars = set()
param_types = self._analyze_param_types(method)
for i, param in enumerate(method.params): for i, param in enumerate(method.params):
if param.is_variadic: if param.is_variadic:
self.emit(f'local -a {param.name}=("${{@:{i + 1}}}")') self.emit(f'local -a {param.name}=("${{@:{i + 1}}}")')
...@@ -189,6 +206,8 @@ class ClassMixin: ...@@ -189,6 +206,8 @@ class ClassMixin:
else: else:
self.emit(f'local {param.name}="${{{i + 1}}}"') self.emit(f'local {param.name}="${{{i + 1}}}"')
self.local_vars.add(param.name) self.local_vars.add(param.name)
if param_types.get(param.name) == "object":
self.object_vars.add(param.name)
for stmt in method.body.statements: for stmt in method.body.statements:
self.generate_statement(stmt) self.generate_statement(stmt)
...@@ -196,6 +215,7 @@ class ClassMixin: ...@@ -196,6 +215,7 @@ class ClassMixin:
self.in_class_method = False self.in_class_method = False
self.in_function = old_in_function self.in_function = old_in_function
self.local_vars = old_local_vars self.local_vars = old_local_vars
self.object_vars = old_object_vars
self.indent_level -= 1 self.indent_level -= 1
self.emit("}") self.emit("}")
self.emit() self.emit()
...@@ -222,9 +242,11 @@ class ClassMixin: ...@@ -222,9 +242,11 @@ class ClassMixin:
self.in_class_method = True self.in_class_method = True
old_in_function = self.in_function old_in_function = self.in_function
old_local_vars = self.local_vars.copy() old_local_vars = self.local_vars.copy()
old_object_vars = self.object_vars.copy()
self.in_function = True self.in_function = True
self.local_vars = set() self.local_vars = set()
param_types = self._analyze_param_types(method)
for i, param in enumerate(method.params): for i, param in enumerate(method.params):
if param.is_variadic: if param.is_variadic:
self.emit(f'local -a {param.name}=("${{@:{i + 1}}}")') self.emit(f'local -a {param.name}=("${{@:{i + 1}}}")')
...@@ -235,6 +257,8 @@ class ClassMixin: ...@@ -235,6 +257,8 @@ class ClassMixin:
else: else:
self.emit(f'local {param.name}="${{{i + 1}}}"') self.emit(f'local {param.name}="${{{i + 1}}}"')
self.local_vars.add(param.name) self.local_vars.add(param.name)
if param_types.get(param.name) == "object":
self.object_vars.add(param.name)
for stmt in method.body.statements: for stmt in method.body.statements:
self.generate_statement(stmt) self.generate_statement(stmt)
...@@ -242,6 +266,7 @@ class ClassMixin: ...@@ -242,6 +266,7 @@ class ClassMixin:
self.in_class_method = False self.in_class_method = False
self.in_function = old_in_function self.in_function = old_in_function
self.local_vars = old_local_vars self.local_vars = old_local_vars
self.object_vars = old_object_vars
self.indent_level -= 1 self.indent_level -= 1
self.emit("}") self.emit("}")
self.emit() self.emit()
...@@ -313,6 +338,86 @@ class ClassMixin: ...@@ -313,6 +338,86 @@ class ClassMixin:
self.inlineable_methods[(cls.name, method.name)] = \ self.inlineable_methods[(cls.name, method.name)] = \
f'${{__CT_OBJ["$this.{arg0.member}"]:${{__CT_OBJ["$this.{arg1.member}"]}}:1}}' f'${{__CT_OBJ["$this.{arg0.member}"]:${{__CT_OBJ["$this.{arg1.member}"]}}:1}}'
def _analyze_param_types(self, func: FunctionDecl) -> dict:
    """Infer how each parameter of ``func`` is used inside its body.

    The body is walked once, recording which methods are called on every
    parameter and whether the parameter is indexed or iterated.

    Args:
        func: Function or method declaration whose ``params`` and ``body``
            are inspected.

    Returns:
        A dict mapping every parameter name to ``"scalar"`` (the default),
        ``"array"``, ``"dict"`` or ``"object"``.

    Classification, in priority order:
      * any method outside ``ALL_KNOWN_METHODS``  -> ``"object"``
      * any array-only method                     -> ``"array"``
      * any dict-only method                      -> ``"dict"``
      * any string-only method                    -> left as already inferred
      * only shared accessors (get/set/len)       -> ``"array"``
    Indexing a parameter or iterating it with ``foreach`` marks it as
    ``"array"`` before the method-based rules run.

    Only assignments, expression statements, ``return``, ``if``/``elif``/
    ``else``, ``while``, ``for`` and ``foreach`` are traversed; other
    statement kinds are ignored.
    """
    param_names = {p.name for p in func.params}
    param_types = {p.name: "scalar" for p in func.params}
    param_methods = {p.name: set() for p in func.params}

    def mark_array(name):
        # Structural evidence (indexing / iteration) only upgrades a
        # parameter that has no stronger classification yet.
        if name in param_names and param_types[name] == "scalar":
            param_types[name] = "array"

    def visit_expr(expr):
        if isinstance(expr, CallExpr):
            callee = expr.callee
            if isinstance(callee, MemberAccess):
                receiver = callee.object
                if isinstance(receiver, Identifier):
                    if receiver.name in param_names:
                        param_methods[receiver.name].add(callee.member)
                else:
                    # Chained call such as p.keys().join(","): the inner
                    # call on the parameter must still be recorded.
                    visit_expr(receiver)
            # Arguments are visited for every call, including plain
            # function calls like foo(p.push(x)).
            for arg in expr.arguments:
                visit_expr(arg)
        elif isinstance(expr, BinaryOp):
            visit_expr(expr.left)
            visit_expr(expr.right)
        elif isinstance(expr, IndexAccess):
            if isinstance(expr.object, Identifier):
                mark_array(expr.object.name)
            else:
                visit_expr(expr.object)
            visit_expr(expr.index)

    def visit_block(block):
        if block:
            for inner in block.statements:
                visit_stmt(inner)

    def visit_stmt(stmt):
        if isinstance(stmt, Assignment):
            visit_expr(stmt.value)
            if isinstance(stmt.target, IndexAccess):
                # p[i] = ... marks p as an array and scans the index.
                visit_expr(stmt.target)
        elif isinstance(stmt, ExpressionStmt):
            visit_expr(stmt.expression)
        elif isinstance(stmt, ForeachStmt):
            if isinstance(stmt.iterable, Identifier):
                mark_array(stmt.iterable.name)
            else:
                # e.g. foreach k in p.keys()
                visit_expr(stmt.iterable)
            visit_block(stmt.body)
        elif isinstance(stmt, IfStmt):
            visit_expr(stmt.condition)
            visit_block(stmt.then_branch)
            for elif_condition, elif_block in stmt.elif_branches:
                # elif conditions can call methods on parameters too.
                visit_expr(elif_condition)
                visit_block(elif_block)
            visit_block(stmt.else_branch)
        elif isinstance(stmt, (WhileStmt, ForStmt)):
            # Not every loop node is guaranteed to carry a condition.
            visit_expr(getattr(stmt, 'condition', None))
            visit_block(stmt.body)
        elif isinstance(stmt, ReturnStmt):
            if stmt.value:
                visit_expr(stmt.value)

    visit_block(func.body)

    # Methods that exist on strings but on neither arrays nor dicts.
    string_only_methods = STRING_METHODS_ALL - ARRAY_METHODS_ALL - DICT_METHODS_ALL

    for param_name, methods in param_methods.items():
        if methods - ALL_KNOWN_METHODS:
            param_types[param_name] = "object"
        elif methods & ARRAY_ONLY_METHODS:
            param_types[param_name] = "array"
        elif methods & DICT_ONLY_METHODS:
            param_types[param_name] = "dict"
        elif methods & string_only_methods:
            # A string receiver: s.len() next to s.upper() must not turn
            # the parameter into an array passed by reference.
            continue
        elif methods & ARRAY_METHODS_ALL:
            # Only shared accessors were seen; default to array.
            param_types[param_name] = "array"

    return param_types
def generate_function(self, func: FunctionDecl): def generate_function(self, func: FunctionDecl):
test_decorator = None test_decorator = None
...@@ -357,9 +462,23 @@ class ClassMixin: ...@@ -357,9 +462,23 @@ class ClassMixin:
self.emit(f"{name} () {{") self.emit(f"{name} () {{")
self.indent_level += 1 self.indent_level += 1
param_types = self._analyze_param_types(func)
old_param_name_map = getattr(self, 'param_name_map', {})
self.param_name_map = {}
for i, param in enumerate(func.params): for i, param in enumerate(func.params):
ptype = param_types.get(param.name, "scalar")
if param.is_variadic: if param.is_variadic:
self.emit(f'local -a {param.name}=("${{@:{i + 1}}}")') self.emit(f'local -a {param.name}=("${{@:{i + 1}}}")')
elif ptype in ("array", "dict"):
nameref_name = f"__ct_{func.name}_{param.name}"
self.emit(f'local -n {nameref_name}="${{{i + 1}}}"')
self.param_name_map[param.name] = nameref_name
if ptype == "array":
self.array_vars.add(nameref_name)
else:
self.dict_vars.add(nameref_name)
else: else:
if param.default is not None: if param.default is not None:
default_val = self.generate_expr(param.default) default_val = self.generate_expr(param.default)
...@@ -387,6 +506,7 @@ class ClassMixin: ...@@ -387,6 +506,7 @@ class ClassMixin:
self.deferred_calls = old_deferred self.deferred_calls = old_deferred
self.in_function = old_in_function self.in_function = old_in_function
self.local_vars = old_local_vars self.local_vars = old_local_vars
self.param_name_map = old_param_name_map
self.indent_level -= 1 self.indent_level -= 1
self.emit("}") self.emit("}")
......
from typing import List, Dict, Optional, Set from typing import List, Dict, Optional, Set
from .ast_nodes import * from .ast_nodes import Program, ClassDecl, FunctionDecl, Assignment, Identifier
from .errors import ErrorCollector from .errors import ErrorCollector
from .stdlib import StdlibMixin from .stdlib import StdlibMixin
from .awk_codegen import AwkCodegenMixin from .awk_codegen import AwkCodegenMixin
...@@ -52,6 +52,7 @@ class CodeGenerator(StdlibMixin, AwkCodegenMixin, ExprMixin, StmtMixin, ...@@ -52,6 +52,7 @@ class CodeGenerator(StdlibMixin, AwkCodegenMixin, ExprMixin, StmtMixin,
self.nameref_vars: Set[str] = set() # vars that are namerefs to arrays/dicts self.nameref_vars: Set[str] = set() # vars that are namerefs to arrays/dicts
self.instance_vars: Dict[str, str] = {} # var_name -> class_name self.instance_vars: Dict[str, str] = {} # var_name -> class_name
self.class_field_types: Dict[tuple, str] = {} self.class_field_types: Dict[tuple, str] = {}
self.func_param_types: Dict[tuple, str] = {} # (func_name, param_name) -> "array"/"dict"
self.local_vars: Set[str] = set() self.local_vars: Set[str] = set()
self.current_param_positions: Dict[str, int] = {} # param_name -> position (1-based) self.current_param_positions: Dict[str, int] = {} # param_name -> position (1-based)
...@@ -80,6 +81,21 @@ class CodeGenerator(StdlibMixin, AwkCodegenMixin, ExprMixin, StmtMixin, ...@@ -80,6 +81,21 @@ class CodeGenerator(StdlibMixin, AwkCodegenMixin, ExprMixin, StmtMixin,
def indent(self) -> str: def indent(self) -> str:
return " " * self.indent_level return " " * self.indent_level
class _IndentContext:
    """Context manager that raises a generator's indent level by one.

    The level is incremented on entry and decremented on exit, so code
    emitted inside the ``with`` block is indented one step deeper.
    """

    def __init__(self, gen):
        # ``gen`` is any object exposing a mutable integer ``indent_level``.
        self.gen = gen

    def __enter__(self):
        generator = self.gen
        generator.indent_level = generator.indent_level + 1
        return self

    def __exit__(self, *exc_info):
        # Returns None, so an exception raised inside the block propagates.
        generator = self.gen
        generator.indent_level = generator.indent_level - 1
def indented(self):
    """Return a context manager that indents code emitted inside it.

    Intended for ``with self.indented():`` blocks; the returned object is
    a ``_IndentContext`` bound to this generator.
    """
    indent_scope = self._IndentContext(self)
    return indent_scope
def emit(self, line: str = ""): def emit(self, line: str = ""):
if line: if line:
self.output.append(f"{self.indent()}{line}") self.output.append(f"{self.indent()}{line}")
......
"""Constants for bash code generation."""
# --- Names of variables / stores used by the generated bash code ----------
RET_VAR = "__CT_RET"
OBJ_STORE = "__CT_OBJ"
THIS_INSTANCE = "__ct_this_instance"

# --- Prefixes for generated identifiers -----------------------------------
TMP_PREFIX = "__ct_tmp_"
CLASS_FUNC_PREFIX = "__ct_class_"
LAMBDA_PREFIX = "__ct_lambda_"

# --- Prefixes of the runtime helper function families ---------------------
# (e.g. the string helpers are named "__ct_str_len", "__ct_str_upper", ...)
STR_FUNC_PREFIX = "__ct_str_"
ARR_FUNC_PREFIX = "__ct_arr_"
DICT_FUNC_PREFIX = "__ct_dict_"
FH_FUNC_PREFIX = "__ct_fh_"
FS_FUNC_PREFIX = "__ct_fs_"
HTTP_FUNC_PREFIX = "__ct_http_"
JSON_FUNC_PREFIX = "__ct_json_"
REGEX_FUNC_PREFIX = "__ct_regex_"
MATH_FUNC_PREFIX = "__ct_math_"
from .ast_nodes import * from typing import Dict, Any
from .ast_nodes import (
Expression, CallExpr, MemberAccess, ThisExpr, Identifier,
BinaryOp, UnaryOp, BoolLiteral
)
class NodeIdMap:
    """Map AST nodes to values by object identity.

    Nodes are keyed by ``id(node)``, so they need not be hashable and two
    structurally equal nodes stay distinct.  Every stored node is kept
    referenced for the lifetime of the map: ``id()`` values can be reused
    once an object is garbage-collected, which would otherwise let a new
    node alias a stale entry.
    """

    def __init__(self):
        # id(node) -> value
        self._map: Dict[int, Any] = {}
        # id(node) -> node; one strong reference per stored node, so
        # re-setting the same node does not accumulate duplicates.
        self._refs: Dict[int, Any] = {}

    def set(self, node, value):
        """Associate ``value`` with ``node``, replacing any previous value."""
        key = id(node)
        self._refs[key] = node
        self._map[key] = value

    def get(self, node, default=None):
        """Return the value stored for ``node``, or ``default`` if absent."""
        return self._map.get(id(node), default)

    def __contains__(self, node):
        return id(node) in self._map

    def __getitem__(self, node):
        """Return the value stored for ``node``; raise ``KeyError`` if absent."""
        return self._map[id(node)]
class CseMixin: class CseMixin:
...@@ -39,7 +64,7 @@ class CseMixin: ...@@ -39,7 +64,7 @@ class CseMixin:
self.collect_method_calls(condition, calls) self.collect_method_calls(condition, calls)
seen = {} seen = {}
mapping = {} mapping = NodeIdMap()
regen_code = [] regen_code = []
for call in calls: for call in calls:
...@@ -56,7 +81,7 @@ class CseMixin: ...@@ -56,7 +81,7 @@ class CseMixin:
self.emit(assign_line) self.emit(assign_line)
seen[key] = temp seen[key] = temp
regen_code.append((call_line, assign_line)) regen_code.append((call_line, assign_line))
mapping[id(call)] = seen[key] mapping.set(call, seen[key])
return mapping, regen_code return mapping, regen_code
...@@ -66,7 +91,7 @@ class CseMixin: ...@@ -66,7 +91,7 @@ class CseMixin:
self.collect_all_calls(condition, calls) self.collect_all_calls(condition, calls)
seen = {} seen = {}
mapping = {} mapping = NodeIdMap()
regen_code = [] regen_code = []
for call in calls: for call in calls:
...@@ -85,7 +110,7 @@ class CseMixin: ...@@ -85,7 +110,7 @@ class CseMixin:
self.emit(assign_line) self.emit(assign_line)
seen[key] = temp seen[key] = temp
regen_code.append((call_line, assign_line)) regen_code.append((call_line, assign_line))
mapping[id(call)] = seen[key] mapping.set(call, seen[key])
elif isinstance(call.callee.object, Identifier): elif isinstance(call.callee.object, Identifier):
obj_name = call.callee.object.name obj_name = call.callee.object.name
...@@ -104,7 +129,7 @@ class CseMixin: ...@@ -104,7 +129,7 @@ class CseMixin:
self.emit(call_line) self.emit(call_line)
seen[key] = temp seen[key] = temp
regen_code.append((call_line, "")) regen_code.append((call_line, ""))
mapping[id(call)] = seen[key] mapping.set(call, seen[key])
elif isinstance(call.callee, Identifier): elif isinstance(call.callee, Identifier):
func_name = call.callee.name func_name = call.callee.name
...@@ -120,7 +145,7 @@ class CseMixin: ...@@ -120,7 +145,7 @@ class CseMixin:
self.emit(assign_line) self.emit(assign_line)
seen[key] = temp seen[key] = temp
regen_code.append((call_line, assign_line)) regen_code.append((call_line, assign_line))
mapping[id(call)] = seen[key] mapping.set(call, seen[key])
return mapping, regen_code return mapping, regen_code
...@@ -170,19 +195,19 @@ class CseMixin: ...@@ -170,19 +195,19 @@ class CseMixin:
if isinstance(expr, BoolLiteral): if isinstance(expr, BoolLiteral):
return "true" if expr.value else "false" return "true" if expr.value else "false"
if isinstance(expr, CallExpr) and id(expr) in mapping: if isinstance(expr, CallExpr) and expr in mapping:
return f'[[ "${{{mapping[id(expr)]}}}" == "true" ]]' return f'[[ "${{{mapping[expr]}}}" == "true" ]]'
return self.generate_condition(expr) return self.generate_condition(expr)
def generate_expr_with_precompute(self, expr: Expression, mapping: dict) -> str: def generate_expr_with_precompute(self, expr: Expression, mapping: NodeIdMap) -> str:
"""Generate expression using pre-computed values.""" """Generate expression using pre-computed values."""
if isinstance(expr, CallExpr) and id(expr) in mapping: if isinstance(expr, CallExpr) and expr in mapping:
return f'${mapping[id(expr)]}' return f'${mapping[expr]}'
if isinstance(expr, MemberAccess): if isinstance(expr, MemberAccess):
if isinstance(expr.object, CallExpr) and id(expr.object) in mapping: if isinstance(expr.object, CallExpr) and expr.object in mapping:
temp = mapping[id(expr.object)] temp = mapping[expr.object]
return f'${{__CT_OBJ["${temp}.{expr.member}"]}}' return f'${{__CT_OBJ["${temp}.{expr.member}"]}}'
return self.generate_expr(expr) return self.generate_expr(expr)
......
"""Dead Code Elimination.""" """Dead Code Elimination."""
from .ast_nodes import * import re
from .ast_nodes import (
ClassDecl, NewExpr, CallExpr, Identifier, FunctionDecl, Assignment,
ExpressionStmt, IfStmt, ForStmt, ForeachStmt, WhileStmt, WhenStmt,
WhenBranch, TryStmt, ThrowStmt, DeferStmt, ReturnStmt, ArrayLiteral,
DictLiteral, IndexAccess, Lambda, MemberAccess, ThisExpr, Block,
BinaryOp, UnaryOp, WithStmt, Program, StringLiteral
)
class UsageAnalyzer: class UsageAnalyzer:
...@@ -11,19 +19,26 @@ class UsageAnalyzer: ...@@ -11,19 +19,26 @@ class UsageAnalyzer:
'args', 'misc', 'args', 'misc',
} }
ARRAY_RETURNING_METHODS = {'keys', 'split', 'slice'}
DICT_RETURNING_METHODS = set()
def __init__(self): def __init__(self):
self.used: set = set() self.used: set = set()
self.has_classes = False self.has_classes = False
self.has_awk = False self.has_awk = False
self.test_mode = False self.test_mode = False
self.defined_classes: dict = {} self.defined_classes: dict = {}
self.defined_functions: dict = {}
self.used_classes: set = set() self.used_classes: set = set()
self.used_methods: dict = {} self.used_methods: dict = {}
self.class_fields: dict = {} self.class_fields: dict = {}
self.variable_types: dict = {} self.variable_types: dict = {}
self.array_variables: set = set()
self.dict_variables: set = set()
self.current_class_name: str = None self.current_class_name: str = None
self.current_method_name: str = None self.current_method_name: str = None
self.method_calls: dict = {} self.method_calls: dict = {}
self.func_param_types: dict = {}
def analyze(self, programs: list, test_mode: bool = False) -> set: def analyze(self, programs: list, test_mode: bool = False) -> set:
self.used = {'core'} self.used = {'core'}
...@@ -34,6 +49,8 @@ class UsageAnalyzer: ...@@ -34,6 +49,8 @@ class UsageAnalyzer:
if isinstance(stmt, ClassDecl): if isinstance(stmt, ClassDecl):
self.defined_classes[stmt.name] = stmt self.defined_classes[stmt.name] = stmt
self._collect_class_fields(stmt) self._collect_class_fields(stmt)
elif isinstance(stmt, FunctionDecl):
self.defined_functions[stmt.name] = stmt
for program in programs: for program in programs:
for stmt in program.statements: for stmt in program.statements:
...@@ -140,7 +157,10 @@ class UsageAnalyzer: ...@@ -140,7 +157,10 @@ class UsageAnalyzer:
for dec in stmt.decorators: for dec in stmt.decorators:
if dec.name == 'awk': if dec.name == 'awk':
self.has_awk = True self.has_awk = True
old_func_name = getattr(self, 'current_func_name', None)
self.current_func_name = stmt.name
self._analyze_body(stmt.body) self._analyze_body(stmt.body)
self.current_func_name = old_func_name
elif isinstance(stmt, Assignment): elif isinstance(stmt, Assignment):
self._analyze_expr(stmt.value) self._analyze_expr(stmt.value)
...@@ -148,10 +168,21 @@ class UsageAnalyzer: ...@@ -148,10 +168,21 @@ class UsageAnalyzer:
var_name = stmt.target.name var_name = stmt.target.name
if isinstance(stmt.value, NewExpr): if isinstance(stmt.value, NewExpr):
self.variable_types[var_name] = stmt.value.class_name self.variable_types[var_name] = stmt.value.class_name
elif isinstance(stmt.value, CallExpr) and isinstance(stmt.value.callee, Identifier): elif isinstance(stmt.value, CallExpr):
callee_name = stmt.value.callee.name if isinstance(stmt.value.callee, Identifier):
if callee_name in self.defined_classes: callee_name = stmt.value.callee.name
self.variable_types[var_name] = callee_name if callee_name in self.defined_classes:
self.variable_types[var_name] = callee_name
elif isinstance(stmt.value.callee, MemberAccess):
method = stmt.value.callee.member
if method in self.ARRAY_RETURNING_METHODS:
self.array_variables.add(var_name)
elif method in self.DICT_RETURNING_METHODS:
self.dict_variables.add(var_name)
elif isinstance(stmt.value, ArrayLiteral):
self.array_variables.add(var_name)
elif isinstance(stmt.value, DictLiteral):
self.dict_variables.add(var_name)
elif isinstance(stmt, ExpressionStmt): elif isinstance(stmt, ExpressionStmt):
self._analyze_expr(stmt.expression) self._analyze_expr(stmt.expression)
...@@ -262,6 +293,36 @@ class UsageAnalyzer: ...@@ -262,6 +293,36 @@ class UsageAnalyzer:
elif isinstance(expr, Identifier): elif isinstance(expr, Identifier):
pass pass
elif isinstance(expr, StringLiteral):
if getattr(expr, 'has_interpolation', False):
self._analyze_string_interpolation(expr.value)
def _analyze_string_interpolation(self, value: str):
    """Mark methods called inside string interpolation as used.

    Scans ``value`` for ``{var.method(...)}`` fragments and records
    ``method`` in ``self.used_methods`` for the class(es) ``var`` may be:

      1. the class recorded for ``var`` in ``self.variable_types``;
      2. otherwise, inside a function, the classes recorded for that
         parameter in ``self.func_param_types``;
      3. otherwise (receiver type unknown) every defined class that
         declares a method of that name.

    Step 3 is deliberately conservative: whenever the receiver cannot be
    resolved the method is kept for all classes that define it, so live
    methods are never eliminated because of a missing type.

    Args:
        value: Raw text of the interpolated string literal.
    """
    # Matches {name.method(args)}; args may not contain ')'.
    pattern = r'\{(\w+)\.(\w+)\s*\([^)]*\)\}'
    current_func = getattr(self, 'current_func_name', None)

    for match in re.finditer(pattern, value):
        var_name = match.group(1)
        method = match.group(2)

        if var_name in self.variable_types:
            candidates = [self.variable_types[var_name]]
        else:
            candidates = None
            if current_func:
                candidates = self.func_param_types.get((current_func, var_name))
            if candidates is None:
                # Unknown receiver: fall back to every class that has
                # a method with this name.
                candidates = [
                    cls_name
                    for cls_name, cls_decl in self.defined_classes.items()
                    if any(m.name == method for m in cls_decl.methods)
                ]

        for obj_class in candidates:
            self.used_methods.setdefault(obj_class, set()).add(method)
def _analyze_call(self, expr: CallExpr): def _analyze_call(self, expr: CallExpr):
callee = expr.callee callee = expr.callee
...@@ -274,6 +335,8 @@ class UsageAnalyzer: ...@@ -274,6 +335,8 @@ class UsageAnalyzer:
self.used.add('array') self.used.add('array')
elif callee.name in ('random', 'random_range'): elif callee.name in ('random', 'random_range'):
self.used.add('misc') self.used.add('misc')
elif callee.name in self.defined_functions:
self._analyze_function_call_with_types(callee.name, expr.arguments)
if isinstance(callee, MemberAccess): if isinstance(callee, MemberAccess):
if isinstance(callee.object, ThisExpr): if isinstance(callee.object, ThisExpr):
...@@ -315,6 +378,19 @@ class UsageAnalyzer: ...@@ -315,6 +378,19 @@ class UsageAnalyzer:
if obj_class not in self.used_methods: if obj_class not in self.used_methods:
self.used_methods[obj_class] = set() self.used_methods[obj_class] = set()
self.used_methods[obj_class].add(method) self.used_methods[obj_class].add(method)
elif ns in self.array_variables:
self.used.add('array')
elif ns in self.dict_variables:
self.used.add('dict')
elif hasattr(self, 'current_func_name') and self.current_func_name:
key = (self.current_func_name, ns)
if key in self.func_param_types:
for obj_class in self.func_param_types[key]:
if obj_class not in self.used_methods:
self.used_methods[obj_class] = set()
self.used_methods[obj_class].add(method)
else:
self._check_method(method)
elif ns == 'http': elif ns == 'http':
self.used.add('http') self.used.add('http')
elif ns == 'fs': elif ns == 'fs':
...@@ -347,6 +423,46 @@ class UsageAnalyzer: ...@@ -347,6 +423,46 @@ class UsageAnalyzer:
if not found_in_class: if not found_in_class:
self._check_method(method) self._check_method(method)
def _analyze_function_call_with_types(self, func_name: str, arguments: list):
    """Propagate class types of call arguments to the callee's parameters.

    For every positional argument whose class can be determined, the class
    is recorded in ``self.func_param_types[(func_name, param_name)]``.  If
    anything new was recorded, the callee's body is analyzed again with
    ``current_func_name`` set to the callee, so method calls on its
    parameters are resolved against the newly known classes.

    An argument's class is taken from:
      * ``self.variable_types`` for an identifier;
      * otherwise, for an identifier that is a parameter of the function
        currently being analyzed, *all* classes recorded for that
        parameter (a forwarded parameter may hold instances of several
        classes, and each of them must reach the callee);
      * the class name of a ``new`` expression;
      * the callee name of a call that names a defined class.
    Only classes present in ``self.defined_classes`` are recorded.

    Args:
        func_name: Name of the called function.
        arguments: Argument expressions of the call, in positional order.
    """
    func_decl = self.defined_functions.get(func_name)
    if not func_decl:
        return

    caller = getattr(self, 'current_func_name', None)
    new_types_added = False

    # zip() stops at the shorter sequence, ignoring surplus arguments.
    for param, arg in zip(func_decl.params, arguments):
        candidates = set()
        if isinstance(arg, Identifier):
            known_type = self.variable_types.get(arg.name)
            if known_type:
                candidates.add(known_type)
            elif caller:
                candidates.update(self.func_param_types.get((caller, arg.name), ()))
        elif isinstance(arg, NewExpr):
            candidates.add(arg.class_name)
        elif isinstance(arg, CallExpr) and isinstance(arg.callee, Identifier):
            candidates.add(arg.callee.name)

        classes = {c for c in candidates if c and c in self.defined_classes}
        if not classes:
            # Do not create an empty entry: other analysis steps treat the
            # mere presence of the key as "parameter type is known".
            continue

        recorded = self.func_param_types.setdefault((func_name, param.name), set())
        if not classes <= recorded:
            recorded |= classes
            new_types_added = True

    if new_types_added:
        # Re-run the callee's body now that more receiver types are known.
        # This terminates because the set of (param, class) pairs is finite
        # and the body is only revisited when it grows.
        self.current_func_name = func_name
        self._analyze_body(func_decl.body)
        self.current_func_name = caller
def _analyze_member_access(self, expr: MemberAccess): def _analyze_member_access(self, expr: MemberAccess):
if isinstance(expr.object, Identifier): if isinstance(expr.object, Identifier):
pass pass
......
import re import re
from typing import List from typing import List
from .ast_nodes import * from .ast_nodes import Decorator, Parameter
class DecoratorMixin: class DecoratorMixin:
......
import re import re
from .ast_nodes import * from .ast_nodes import (
Expression, IntegerLiteral, FloatLiteral, StringLiteral, BoolLiteral,
NilLiteral, Identifier, ThisExpr, ArrayLiteral, DictLiteral, BinaryOp,
UnaryOp, CallExpr, MemberAccess, IndexAccess, Lambda, NewExpr, BaseCall,
Block, ReturnStmt
)
class ExprMixin: class ExprMixin:
...@@ -22,7 +27,11 @@ class ExprMixin: ...@@ -22,7 +27,11 @@ class ExprMixin:
return "" return ""
if isinstance(expr, Identifier): if isinstance(expr, Identifier):
return f"${{{expr.name}}}" name = expr.name
param_map = getattr(self, 'param_name_map', {})
if name in param_map:
name = param_map[name]
return f"${{{name}}}"
if isinstance(expr, ThisExpr): if isinstance(expr, ThisExpr):
return "$this" return "$this"
......
...@@ -176,7 +176,7 @@ def cmd_run (args): ...@@ -176,7 +176,7 @@ def cmd_run (args):
finally: finally:
try: try:
os.unlink (temp_path) os.unlink (temp_path)
except: except OSError:
pass pass
except Exception as e: except Exception as e:
...@@ -238,7 +238,7 @@ def cmd_test (args): ...@@ -238,7 +238,7 @@ def cmd_test (args):
finally: finally:
try: try:
os.unlink (temp_path) os.unlink (temp_path)
except: except OSError:
pass pass
except Exception as e: except Exception as e:
......
"""Unified method registry for bash and AWK code generation.
This module provides a single source of truth for all builtin methods,
ensuring consistency between bash and AWK code generators.
"""
from dataclasses import dataclass
from typing import Optional, Callable, List
@dataclass
class MethodDef:
    """Describe one builtin method in the unified bash/AWK method registry.

    An entry bundles the method's name, its argument-count bounds and, where
    available, the bash helper function and the AWK code generator that
    implement it.
    """

    # Method name; the registry dicts in this module key entries by the same string.
    name: str
    # Lower bound on the argument count.
    # NOTE(review): nothing in this module enforces the bounds — presumably
    # the code generators validate against them; confirm.
    min_args: int = 0
    # Upper bound on the argument count; None presumably means "unbounded" — confirm.
    max_args: Optional[int] = None
    # Name of the bash helper function (returned by get_bash_func), or None.
    bash_func: Optional[str] = None
    # Callable (obj, args) -> AWK source text (invoked by generate_awk), or
    # None when the method has no AWK generator.
    awk_gen: Optional[Callable[[str, List[str]], str]] = None
    # Set to True in this module for split, slice, map, filter and keys;
    # presumably marks methods whose result is an array — confirm with callers.
    returns_array: bool = False
# Builtin string methods, keyed by method name.
#
# Every AWK generator receives the receiver expression (`obj`) and the list of
# already-generated argument expressions (`args`) and returns AWK source text.
# The `+ 1` / `- 1` adjustments translate between the 0-based positions used
# by callers and AWK's 1-based string functions.
STRING_METHODS = {
    "len": MethodDef(
        name="len", min_args=0, max_args=0,
        bash_func="__ct_str_len",
        awk_gen=lambda obj, args: f"length({obj})",
    ),
    "upper": MethodDef(
        name="upper", min_args=0, max_args=0,
        bash_func="__ct_str_upper",
        awk_gen=lambda obj, args: f"toupper({obj})",
    ),
    "lower": MethodDef(
        name="lower", min_args=0, max_args=0,
        bash_func="__ct_str_lower",
        awk_gen=lambda obj, args: f"tolower({obj})",
    ),
    # NOTE(review): gsub rewrites its target in place, so `obj` presumably has
    # to be an assignable AWK expression here and in "replace" — confirm.
    "trim": MethodDef(
        name="trim", min_args=0, max_args=0,
        bash_func="__ct_str_trim",
        awk_gen=lambda obj, args: f'(gsub(/^[ \\t]+|[ \\t]+$/, "", {obj}) ? {obj} : {obj})',
    ),
    "contains": MethodDef(
        name="contains", min_args=1, max_args=1,
        bash_func="__ct_str_contains",
        awk_gen=lambda obj, args: f"(index({obj}, {args[0]}) > 0)",
    ),
    "starts": MethodDef(
        name="starts", min_args=1, max_args=1,
        bash_func="__ct_str_starts",
        awk_gen=lambda obj, args: f"(substr({obj}, 1, length({args[0]})) == {args[0]})",
    ),
    "ends": MethodDef(
        name="ends", min_args=1, max_args=1,
        bash_func="__ct_str_ends",
        awk_gen=lambda obj, args: f"(substr({obj}, length({obj}) - length({args[0]}) + 1) == {args[0]})",
    ),
    "index": MethodDef(
        name="index", min_args=1, max_args=1,
        bash_func="__ct_str_index",
        awk_gen=lambda obj, args: f"(index({obj}, {args[0]}) - 1)",
    ),
    "replace": MethodDef(
        name="replace", min_args=2, max_args=2,
        bash_func="__ct_str_replace",
        awk_gen=lambda obj, args: f"(gsub({args[0]}, {args[1]}, {obj}) ? {obj} : {obj})",
    ),
    "substr": MethodDef(
        name="substr", min_args=2, max_args=2,
        bash_func="__ct_str_substr",
        awk_gen=lambda obj, args: f"substr({obj}, {args[0]} + 1, {args[1]})",
    ),
    "split": MethodDef(
        name="split", min_args=1, max_args=1,
        bash_func="__ct_str_split",
        awk_gen=lambda obj, args: f"split({obj}, __split_arr, {args[0]})",
        returns_array=True,
    ),
    "charAt": MethodDef(
        name="charAt", min_args=1, max_args=1,
        bash_func="__ct_str_char_at",
        awk_gen=lambda obj, args: f"substr({obj}, {args[0]} + 1, 1)",
    ),
    # No AWK generator is registered for urlencode.
    "urlencode": MethodDef(
        name="urlencode", min_args=0, max_args=0,
        bash_func="__ct_str_urlencode",
        awk_gen=None,
    ),
}
# Builtin array methods, keyed by method name.
#
# Entries whose `awk_gen` is None (slice, map, filter) have no AWK generator.
ARRAY_METHODS = {
    "len": MethodDef(
        name="len", min_args=0, max_args=0,
        bash_func="__ct_arr_len",
        awk_gen=lambda obj, args: f"length({obj})",
    ),
    "push": MethodDef(
        name="push", min_args=1, max_args=1,
        bash_func="__ct_arr_push",
        awk_gen=lambda obj, args: f"{obj}[length({obj}) + 1] = {args[0]}",
    ),
    "pop": MethodDef(
        name="pop", min_args=0, max_args=0,
        bash_func="__ct_arr_pop",
        awk_gen=lambda obj, args: f"delete {obj}[length({obj})]",
    ),
    "shift": MethodDef(
        name="shift", min_args=0, max_args=0,
        bash_func="__ct_arr_shift",
        awk_gen=lambda obj, args: f"delete {obj}[1]",
    ),
    "join": MethodDef(
        name="join", min_args=1, max_args=1,
        bash_func="__ct_arr_join",
        awk_gen=lambda obj, args: f"__ct_awk_join({obj}, {args[0]})",
    ),
    "get": MethodDef(
        name="get", min_args=1, max_args=1,
        bash_func="__ct_arr_get",
        awk_gen=lambda obj, args: f"{obj}[{args[0]}]",
    ),
    "set": MethodDef(
        name="set", min_args=2, max_args=2,
        bash_func="__ct_arr_set",
        awk_gen=lambda obj, args: f"{obj}[{args[0]}] = {args[1]}",
    ),
    "slice": MethodDef(
        name="slice", min_args=2, max_args=2,
        bash_func="__ct_arr_slice",
        awk_gen=None,
        returns_array=True,
    ),
    "map": MethodDef(
        name="map", min_args=1, max_args=1,
        bash_func="__ct_arr_map",
        awk_gen=None,
        returns_array=True,
    ),
    "filter": MethodDef(
        name="filter", min_args=1, max_args=1,
        bash_func="__ct_arr_filter",
        awk_gen=None,
        returns_array=True,
    ),
}
# Builtin dict methods, keyed by method name.
#
# Each AWK generator receives the receiver expression (`obj`) and the list of
# generated argument expressions (`args`) and returns AWK source text.
DICT_METHODS = {
    "get": MethodDef("get", 1, 1, "__ct_dict_get",
                     lambda obj, args: f"{obj}[{args[0]}]"),
    "set": MethodDef("set", 2, 2, "__ct_dict_set",
                     lambda obj, args: f"{obj}[{args[0]}] = {args[1]}"),
    "has": MethodDef("has", 1, 1, "__ct_dict_has",
                     lambda obj, args: f"({args[0]} in {obj})"),
    "del": MethodDef("del", 1, 1, "__ct_dict_del",
                     lambda obj, args: f"delete {obj}[{args[0]}]"),
    # No AWK generator is registered for keys.
    "keys": MethodDef("keys", 0, 0, "__ct_dict_keys", None, returns_array=True),
    # The bash stdlib emits a __ct_dict_len helper; register it here so the
    # registry stays in sync with the generated runtime. The AWK form mirrors
    # the array "len" entry.
    "len": MethodDef("len", 0, 0, "__ct_dict_len",
                     lambda obj, args: f"length({obj})"),
}
# Builtin file-handle methods, keyed by method name.
#
# Every file-handle method takes a fixed number of arguments (min == max) and
# has no AWK generator, so the table only lists name, arity and bash helper.
FILE_HANDLE_METHODS = {
    method: MethodDef(method, arity, arity, bash_helper, None)
    for method, arity, bash_helper in (
        ("read", 0, "__ct_fh_read"),
        ("readline", 0, "__ct_fh_readline"),
        ("write", 1, "__ct_fh_write"),
        ("writeln", 1, "__ct_fh_writeln"),
        ("close", 0, "__ct_fh_close"),
    )
}
# Builtin namespaces mapped to the set of function names each one exposes.
NAMESPACE_METHODS = dict(
    fs={"read", "write", "append", "exists", "remove", "mkdir", "list", "open"},
    http={"get", "post", "put", "delete"},
    json={"parse", "stringify", "get"},
    logger={"info", "warn", "error", "debug"},
    regex={"match", "extract"},
    args={"count", "get"},
    shell={"exec", "capture", "source"},
    time={"now", "ms"},
    math={"add", "sub", "mul", "div", "mod", "min", "max", "abs"},
)

# Names of all builtin namespaces (the keys of NAMESPACE_METHODS).
BUILTIN_NAMESPACES = set(NAMESPACE_METHODS)

# Names of the builtin free functions.
BUILTIN_FUNCS = {
    "print",
    "exit",
    "len",
    "range",
    "ngrep",
    "is_number",
    "is_empty",
    "chr",
    "ord",
    "assert",
    "assert_eq",
    "random",
    "random_range",
}
def get_method_names(type_name: str) -> set:
    """Return the names of all builtin methods registered for a type.

    Args:
        type_name: One of "string", "array", "dict" or "file_handle".

    Returns:
        A new set with the method names of the matching registry, or an
        empty set when the type name is not recognised.
    """
    # Pick the registry lazily so an unknown type never touches the tables.
    registry = (
        STRING_METHODS if type_name == "string"
        else ARRAY_METHODS if type_name == "array"
        else DICT_METHODS if type_name == "dict"
        else FILE_HANDLE_METHODS if type_name == "file_handle"
        else None
    )
    if registry is None:
        return set()
    return set(registry)
def get_method_def(type_name: str, method_name: str) -> Optional[MethodDef]:
    """Look up the definition of a builtin method.

    Args:
        type_name: One of "string", "array", "dict" or "file_handle".
        method_name: Name of the method to look up on that type.

    Returns:
        The matching MethodDef, or None when either the type or the method
        is not registered.
    """
    registries = {
        "string": STRING_METHODS,
        "array": ARRAY_METHODS,
        "dict": DICT_METHODS,
        "file_handle": FILE_HANDLE_METHODS,
    }
    try:
        registry = registries[type_name]
    except KeyError:
        # Unknown type: fall back to an empty registry so the lookup yields None.
        registry = {}
    return registry.get(method_name)
def get_bash_func(type_name: str, method_name: str) -> Optional[str]:
    """Return the bash helper function name registered for a method.

    Returns:
        The `bash_func` of the matching MethodDef, or None when the method is
        not registered for the given type.
    """
    definition = get_method_def(type_name, method_name)
    if not definition:
        return None
    return definition.bash_func
def generate_awk(type_name: str, method_name: str, obj: str, args: List[str]) -> Optional[str]:
    """Generate AWK source text for a builtin method call.

    Args:
        type_name: Type of the receiver ("string", "array", "dict", ...).
        method_name: Name of the method being called.
        obj: AWK expression for the receiver.
        args: AWK expressions for the call arguments.

    Returns:
        The text produced by the method's AWK generator, or None when the
        method is not registered or has no AWK generator.
    """
    definition = get_method_def(type_name, method_name)
    if not definition:
        return None
    if not definition.awk_gen:
        return None
    return definition.awk_gen(obj, args)
from typing import List, Optional, Callable from typing import List, Optional, Callable, Union
from .tokens import Token, TokenType from .tokens import Token, TokenType
from .ast_nodes import * from .ast_nodes import (
SourceLocation, Program, Declaration, Statement, Decorator, FunctionDecl,
Parameter, ClassDecl, ConstructorDecl, ImportStmt, Block, ReturnStmt,
BreakStmt, ContinueStmt, IfStmt, WhileStmt, ForStmt, ForeachStmt, WithStmt,
TryStmt, ThrowStmt, DeferStmt, WhenStmt, WhenBranch, RangePattern,
ExpressionStmt, Assignment, IntegerLiteral, FloatLiteral, StringLiteral,
BoolLiteral, NilLiteral, ThisExpr, ArrayLiteral, DictLiteral, Identifier,
BinaryOp, UnaryOp, CallExpr, MemberAccess, IndexAccess, NewExpr, Lambda,
BaseCall, Expression
)
from .errors import CompileError, ErrorCollector from .errors import CompileError, ErrorCollector
...@@ -741,7 +750,7 @@ class Parser: ...@@ -741,7 +750,7 @@ class Parser:
return self.parse_lambda_body (params, loc) return self.parse_lambda_body (params, loc)
self.pos = saved_pos self.pos = saved_pos
except: except Exception:
self.pos = saved_pos self.pos = saved_pos
expr = self.parse_expression () expr = self.parse_expression ()
......
...@@ -458,7 +458,7 @@ class StdlibMixin: ...@@ -458,7 +458,7 @@ class StdlibMixin:
self.emit ("__ct_str_starts () { [[ \"$1\" == \"$2\"* ]] && __CT_RET=true || __CT_RET=false; echo \"$__CT_RET\"; }") self.emit ("__ct_str_starts () { [[ \"$1\" == \"$2\"* ]] && __CT_RET=true || __CT_RET=false; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_ends () { [[ \"$1\" == *\"$2\" ]] && __CT_RET=true || __CT_RET=false; echo \"$__CT_RET\"; }") self.emit ("__ct_str_ends () { [[ \"$1\" == *\"$2\" ]] && __CT_RET=true || __CT_RET=false; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_replace () { __CT_RET=\"${1//\"$2\"/\"$3\"}\"; echo \"$__CT_RET\"; }") self.emit ("__ct_str_replace () { __CT_RET=\"${1//\"$2\"/\"$3\"}\"; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_split () { local IFS=\"$2\"; read -ra __arr <<< \"$1\"; printf '%s\\n' \"${__arr[@]}\"; }") self.emit ("__ct_str_split () { local IFS=\"$2\"; read -ra __CT_RET_ARR <<< \"$1\"; }")
self.emit ("__ct_str_trim () { local s=\"$1\"; s=\"${s#\"${s%%[![:space:]]*}\"}\"; __CT_RET=\"${s%\"${s##*[![:space:]]}\"}\" ; echo \"$__CT_RET\"; }") self.emit ("__ct_str_trim () { local s=\"$1\"; s=\"${s#\"${s%%[![:space:]]*}\"}\"; __CT_RET=\"${s%\"${s##*[![:space:]]}\"}\" ; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_upper () { __CT_RET=\"${1^^}\"; echo \"$__CT_RET\"; }") self.emit ("__ct_str_upper () { __CT_RET=\"${1^^}\"; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_lower () { __CT_RET=\"${1,,}\"; echo \"$__CT_RET\"; }") self.emit ("__ct_str_lower () { __CT_RET=\"${1,,}\"; echo \"$__CT_RET\"; }")
...@@ -482,7 +482,7 @@ class StdlibMixin: ...@@ -482,7 +482,7 @@ class StdlibMixin:
self.emit ("__ct_arr_len () { local -n __a=$1; __CT_RET=${#__a[@]}; echo \"$__CT_RET\"; }") self.emit ("__ct_arr_len () { local -n __a=$1; __CT_RET=${#__a[@]}; echo \"$__CT_RET\"; }")
self.emit ("__ct_arr_get () { local -n __a=$1; __CT_RET=\"${__a[$2]}\"; echo \"$__CT_RET\"; }") self.emit ("__ct_arr_get () { local -n __a=$1; __CT_RET=\"${__a[$2]}\"; echo \"$__CT_RET\"; }")
self.emit ("__ct_arr_set () { local -n __a=$1; __a[$2]=\"$3\"; }") self.emit ("__ct_arr_set () { local -n __a=$1; __a[$2]=\"$3\"; }")
self.emit ("__ct_arr_slice () { local -n __a=$1; local -a __r=(\"${__a[@]:$2:$3}\"); printf '%s\\n' \"${__r[@]}\"; }") self.emit ("__ct_arr_slice () { local -n __a=$1; __CT_RET_ARR=(\"${__a[@]:$2:$3}\"); }")
self.emit () self.emit ()
self.emit ("# Array map/filter with lambda functions") self.emit ("# Array map/filter with lambda functions")
...@@ -628,7 +628,8 @@ class StdlibMixin: ...@@ -628,7 +628,8 @@ class StdlibMixin:
self.emit ('__ct_dict_get () { local -n __d="$1"; __CT_RET="${__d[$2]}"; echo "$__CT_RET"; }') self.emit ('__ct_dict_get () { local -n __d="$1"; __CT_RET="${__d[$2]}"; echo "$__CT_RET"; }')
self.emit ('__ct_dict_has () { local -n __d="$1"; [[ -v "__d[$2]" ]] && __CT_RET=true || __CT_RET=false; echo "$__CT_RET"; }') self.emit ('__ct_dict_has () { local -n __d="$1"; [[ -v "__d[$2]" ]] && __CT_RET=true || __CT_RET=false; echo "$__CT_RET"; }')
self.emit ('__ct_dict_del () { local -n __d="$1"; unset "__d[$2]"; }') self.emit ('__ct_dict_del () { local -n __d="$1"; unset "__d[$2]"; }')
self.emit ('__ct_dict_keys () { local -n __d="$1"; printf \'%s\\n\' "${!__d[@]}"; }') self.emit ('__ct_dict_keys () { local -n __d="$1"; __CT_RET_ARR=("${!__d[@]}"); }')
self.emit ('__ct_dict_len () { local -n __d="$1"; __CT_RET=${#__d[@]}; echo "$__CT_RET"; }')
self.emit () self.emit ()
def _emit_misc (self): def _emit_misc (self):
......
from .ast_nodes import * from .ast_nodes import (
FunctionDecl, ClassDecl, ImportStmt, Assignment, ExpressionStmt, IfStmt,
WhileStmt, ForStmt, ForeachStmt, WithStmt, TryStmt, ThrowStmt, DeferStmt,
WhenStmt, RangePattern, ReturnStmt, BreakStmt, ContinueStmt, Block,
CallExpr, Identifier, MemberAccess, ThisExpr, StringLiteral, NewExpr,
BinaryOp, DictLiteral, ArrayLiteral, WhenBranch
)
class StmtMixin: class StmtMixin:
...@@ -250,6 +256,8 @@ class StmtMixin: ...@@ -250,6 +256,8 @@ class StmtMixin:
if isinstance(stmt.iterable, Identifier): if isinstance(stmt.iterable, Identifier):
arr_name = stmt.iterable.name arr_name = stmt.iterable.name
param_map = getattr(self, 'param_name_map', {})
arr_name = param_map.get(arr_name, arr_name)
if len(stmt.variables) == 1: if len(stmt.variables) == 1:
var = stmt.variables[0] var = stmt.variables[0]
self.emit(f'for {var} in "${{{arr_name}[@]}}"; do') self.emit(f'for {var} in "${{{arr_name}[@]}}"; do')
...@@ -293,6 +301,27 @@ class StmtMixin: ...@@ -293,6 +301,27 @@ class StmtMixin:
self.emit("done") self.emit("done")
return return
if stmt.iterable.callee.member == "split" and len(stmt.iterable.arguments) == 1:
str_expr = self.generate_expr(stmt.iterable.callee.object)
delim_arg = self.generate_expr(stmt.iterable.arguments[0])
var = stmt.variables[0]
self.emit(f'__ct_str_split "{str_expr}" "{delim_arg}"')
if len(stmt.variables) == 1:
self.emit(f'for {var} in "${{__CT_RET_ARR[@]}}"; do')
else:
idx_var = stmt.variables[0]
val_var = stmt.variables[1]
self.emit(f'{idx_var}=0')
self.emit(f'for {val_var} in "${{__CT_RET_ARR[@]}}"; do')
self.indent_level += 1
for s in stmt.body.statements:
self.generate_statement(s)
if len(stmt.variables) == 2:
self.emit(f'((++{stmt.variables[0]}))')
self.indent_level -= 1
self.emit("done")
return
iterable = self.generate_expr(stmt.iterable) iterable = self.generate_expr(stmt.iterable)
var = stmt.variables[0] var = stmt.variables[0]
self.emit(f'for {var} in {iterable}; do') self.emit(f'for {var} in {iterable}; do')
...@@ -569,26 +598,29 @@ class StmtMixin: ...@@ -569,26 +598,29 @@ class StmtMixin:
args_str = " ".join([f'"{a}"' for a in args]) args_str = " ".join([f'"{a}"' for a in args])
if obj_name in self.array_vars and method in arr_methods: param_map = getattr(self, 'param_name_map', {})
mapped_name = param_map.get(obj_name, obj_name)
if mapped_name in self.array_vars and method in arr_methods:
func_name = arr_methods[method] func_name = arr_methods[method]
self.emit(f'{func_name} "{obj_name}" {args_str} >/dev/null'.replace(' ', ' ')) self.emit(f'{func_name} "{mapped_name}" {args_str} >/dev/null'.replace(' ', ' '))
self.emit('echo "$__CT_RET"') self.emit('echo "$__CT_RET"')
self.emit("return 0") self.emit("return 0")
return True return True
elif obj_name in self.dict_vars and method in dict_methods: elif mapped_name in self.dict_vars and method in dict_methods:
func_name = dict_methods[method] func_name = dict_methods[method]
self.emit(f'{func_name} "{obj_name}" {args_str} >/dev/null'.replace(' ', ' ')) self.emit(f'{func_name} "{mapped_name}" {args_str} >/dev/null'.replace(' ', ' '))
self.emit('echo "$__CT_RET"') self.emit('echo "$__CT_RET"')
self.emit("return 0") self.emit("return 0")
return True return True
elif method in str_methods: elif method in str_methods:
func_name = str_methods[method] func_name = str_methods[method]
self.emit(f'{func_name} "${{{obj_name}}}" {args_str} >/dev/null'.replace(' ', ' ')) self.emit(f'{func_name} "${{{mapped_name}}}" {args_str} >/dev/null'.replace(' ', ' '))
self.emit('echo "$__CT_RET"') self.emit('echo "$__CT_RET"')
self.emit("return 0") self.emit("return 0")
return True return True
self.emit(f'__ct_call_method "${{{obj_name}}}" "{method}" {args_str} >/dev/null') self.emit(f'__ct_call_method "${{{mapped_name}}}" "{method}" {args_str} >/dev/null')
self.emit('echo "$__CT_RET"') self.emit('echo "$__CT_RET"')
self.emit("return 0") self.emit("return 0")
return True return True
...@@ -1077,3 +1077,330 @@ x = arr.foo() ...@@ -1077,3 +1077,330 @@ x = arr.foo()
assert "Available:" in stdout assert "Available:" in stdout
assert "push" in stdout assert "push" in stdout
assert "pop" in stdout assert "pop" in stdout
class TestFunctionParameterPassing:
    """Tests for passing arrays and dicts to user-defined functions.

    Each test compiles and runs a small CT program through run_ct and checks
    the exit code and the captured standard output.
    """

    def test_array_modification_in_function(self):
        """push() inside a function must be visible in the caller's array."""
        code, stdout, _ = run_ct('''
func modify_arr(arr) {
arr.push(99)
arr.push(100)
}
nums = [1, 2, 3]
print("Before: {nums.len()}")
modify_arr(nums)
print("After: {nums.len()}")
''')
        assert code == 0
        assert "Before: 3" in stdout
        assert "After: 5" in stdout

    def test_array_return_length_from_function(self):
        """A function can return len() of an array parameter."""
        code, stdout, _ = run_ct('''
func get_len(arr) {
return arr.len()
}
nums = [1, 2, 3, 4, 5]
len = get_len(nums)
print("Length: {len}")
''')
        assert code == 0
        assert "Length: 5" in stdout

    def test_dict_modification_in_function(self):
        """set() inside a function must be visible in the caller's dict."""
        code, stdout, _ = run_ct('''
func add_key(d, key, value) {
d.set(key, value)
}
config = {"host": "localhost"}
add_key(config, "port", "8080")
port = config.get("port")
print("Port: {port}")
''')
        assert code == 0
        assert "Port: 8080" in stdout

    def test_dict_has_in_function(self):
        """has() on a dict parameter reports present and missing keys."""
        code, stdout, _ = run_ct('''
func has_key(d, key) {
return d.has(key)
}
config = {"host": "localhost", "port": "8080"}
has_host = has_key(config, "host")
has_debug = has_key(config, "debug")
print("Has host: {has_host}")
print("Has debug: {has_debug}")
''')
        assert code == 0
        assert "Has host: true" in stdout
        assert "Has debug: false" in stdout

    def test_array_get_in_function(self):
        """get() on an array parameter returns the element at that index."""
        code, stdout, _ = run_ct('''
func first_element(arr) {
return arr.get(0)
}
nums = [42, 100, 200]
first = first_element(nums)
print("First: {first}")
''')
        assert code == 0
        assert "First: 42" in stdout

    # Renamed from test_array_slice_in_function: the program never calls
    # slice(); it rewrites elements through get()/set() in a range loop.
    def test_array_set_in_function(self):
        """set() on an array parameter updates the caller's elements."""
        code, stdout, _ = run_ct('''
func double_all(arr) {
for i in range(arr.len()) {
val = arr.get(i)
arr.set(i, val * 2)
}
}
nums = [1, 2, 3]
double_all(nums)
print("{nums.get(0)} {nums.get(1)} {nums.get(2)}")
''')
        assert code == 0
        assert "2 4 6" in stdout

    def test_nested_function_array_passing(self):
        """An array parameter can be forwarded to another function."""
        code, stdout, _ = run_ct('''
func inner(arr) {
arr.push("inner")
}
func outer(arr) {
arr.push("outer")
inner(arr)
}
items = []
outer(items)
print("Items: {items.len()}")
''')
        assert code == 0
        assert "Items: 2" in stdout

    def test_dict_keys_in_function(self):
        """keys() on a dict parameter yields an array with a usable len()."""
        code, stdout, _ = run_ct('''
func count_keys(d) {
keys = d.keys()
return keys.len()
}
config = {"a": "1", "b": "2", "c": "3"}
count = count_keys(config)
print("Key count: {count}")
''')
        assert code == 0
        assert "Key count: 3" in stdout

    def test_mixed_parameters_scalar_and_array(self):
        """Scalar and array parameters can be mixed in one signature."""
        code, stdout, _ = run_ct('''
func add_items(prefix, arr, count) {
for i in range(count) {
arr.push("{prefix}_{i}")
}
}
items = []
add_items("item", items, 3)
print("Count: {items.len()}")
''')
        assert code == 0
        assert "Count: 3" in stdout

    def test_dict_get_in_function(self):
        """has() and get() can be combined on a dict parameter."""
        code, stdout, _ = run_ct('''
func get_value(d, key) {
if d.has(key) {
return d.get(key)
}
return ""
}
data = {"name": "Alice", "age": "30"}
name = get_value(data, "name")
print("Name: {name}")
''')
        assert code == 0
        assert "Name: Alice" in stdout

    def test_array_foreach_in_function(self):
        """foreach can iterate over an array parameter."""
        code, stdout, _ = run_ct('''
func sum_array(arr) {
total = 0
foreach n in arr {
total += n
}
return total
}
nums = [1, 2, 3, 4, 5]
result = sum_array(nums)
print("Sum: {result}")
''')
        assert code == 0
        assert "Sum: 15" in stdout

    def test_string_split_returns_array(self):
        """split() on a string parameter yields an array with a usable len()."""
        code, stdout, _ = run_ct('''
func count_words(text) {
words = text.split(" ")
return words.len()
}
sentence = "hello world foo bar"
count = count_words(sentence)
print("Word count: {count}")
''')
        assert code == 0
        assert "Word count: 4" in stdout

    def test_array_slice_returns_array(self):
        """slice() on an array parameter yields an array with a usable len()."""
        code, stdout, _ = run_ct('''
func first_two(arr) {
sub = arr.slice(0, 2)
return sub.len()
}
nums = [1, 2, 3, 4, 5]
count = first_two(nums)
print("Slice len: {count}")
''')
        assert code == 0
        assert "Slice len: 2" in stdout
class TestClassInstancePassing:
    """Tests for passing class instances to functions and methods.

    Each test runs a small CT program through run_ct and checks the exit code
    and the captured standard output.
    """

    def test_class_to_function_basic(self):
        """Method calls on an instance parameter mutate the caller's object."""
        exit_code, output, _ = run_ct('''
class Counter {
count = 0
construct(initial) {
this.count = initial
}
func increment() {
this.count += 1
}
func get() {
return this.count
}
}
func add_ten(c) {
for i in range(10) {
c.increment()
}
}
counter = new Counter(5)
add_ten(counter)
print("Result: {counter.get()}")
''')
        assert exit_code == 0
        assert "Result: 15" in output

    def test_class_to_function_nested(self):
        """An instance parameter can be forwarded to another function."""
        exit_code, output, _ = run_ct('''
class Counter {
count = 0
construct(initial) {
this.count = initial
}
func increment() {
this.count += 1
}
func get() {
return this.count
}
}
func inner(c) {
c.increment()
return c.get()
}
func outer(c) {
c.increment()
return inner(c)
}
counter = new Counter(0)
result = outer(counter)
print("Result: {result}")
''')
        assert exit_code == 0
        assert "Result: 2" in output

    def test_class_to_method(self):
        """An instance can be passed to a method of another class."""
        exit_code, output, _ = run_ct('''
class Counter {
count = 0
construct(initial) {
this.count = initial
}
func get() {
return this.count
}
func add(amount) {
this.count += amount
}
}
class Calculator {
func double_counter(c) {
val = c.get()
c.add(val)
}
}
counter = new Counter(5)
calc = new Calculator()
calc.double_counter(counter)
print("Result: {counter.get()}")
''')
        assert exit_code == 0
        assert "Result: 10" in output

    def test_multiple_class_params_to_method(self):
        """Two instance parameters of one method stay independent objects."""
        exit_code, output, _ = run_ct('''
class Counter {
count = 0
construct(initial) {
this.count = initial
}
func get() {
return this.count
}
func add(amount) {
this.count += amount
}
}
class Calculator {
func process(c1, c2) {
v1 = c1.get()
v2 = c2.get()
c1.add(v2)
c2.add(v1)
}
}
a = new Counter(5)
b = new Counter(10)
calc = new Calculator()
calc.process(a, b)
print("a={a.get()}, b={b.get()}")
''')
        assert exit_code == 0
        assert "a=15, b=15" in output
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment