def _select_code_region(text, def_pattern):
    """Return the first fenced block of *text* containing *def_pattern*, else *text*."""
    # Any info string is accepted after the opening fence (python, py, python3, ...).
    for fence in re.finditer(r"```[^\n`]*\n(.*?)```", text, flags=re.S):
        body = fence.group(1)
        if re.search(def_pattern, body, flags=re.M):
            return body
    return text


def _top_level_function_via_ast(text, function_name):
    """Return the source lines of the first top-level ``def function_name`` in *text*.

    Returns None when *text* does not parse (the caller then falls back to a
    line scan), and "" when it parses but has no such top-level function.
    """
    try:
        tree = ast.parse(text)
    except (SyntaxError, ValueError):  # ValueError: NUL bytes before Python 3.12
        return None
    lines = text.split("\n")
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name == function_name:
            # On Python 3.8+ lineno is the ``def`` line itself (decorators excluded).
            return "\n".join(lines[node.lineno - 1:node.end_lineno]).rstrip()
    return ""


def _scan_function_lines(chunk):
    """Split *chunk* (which starts at the ``def``) and keep lines up to the first
    non-comment line that begins in column 0.

    Limitation: a column-0 signature continuation or dedented docstring text
    also ends the scan; this path only runs when the full text does not parse.
    """
    lines = chunk.split("\n")
    kept = lines[:1]
    for line in lines[1:]:
        if line[:1] not in ("", " ", "\t", "#"):
            break
        kept.append(line)
    return kept


def _longest_parseable_prefix(lines):
    """Return the longest line-prefix of *lines* that parses.

    If no prefix parses, all of *lines* is returned joined, so that callers
    still see (and can score) the broken source.
    """
    # Longest first: the common case (everything parses) costs one parse.
    for end in range(len(lines), 0, -1):
        candidate = "\n".join(lines[:end]).rstrip()
        try:
            ast.parse(candidate)
        except (SyntaxError, ValueError):
            continue
        return candidate
    return "\n".join(lines).rstrip()


def extract_function_source(full_text, function_name):
    """Pull the source of a top-level ``def function_name(...)`` out of *full_text*.

    *full_text* may be bare code or a reply containing Markdown code fences.
    The first fenced block that contains the ``def`` is used; otherwise the
    whole text is used. Line endings are normalised to "\\n".

    When that text parses, the function's exact line span is taken from the
    AST. Trailing top-level code (``print(...)`` calls, the next function's
    decorators, ``if __name__`` guards) is therefore never included.

    Otherwise the function is cut at the first non-comment column-0 line and
    trimmed to its longest parseable prefix. If no prefix parses, the
    unparseable text is returned unchanged.

    Returns "" when no such function is found.
    """
    text = full_text.replace("\r\n", "\n").replace("\r", "\n")
    def_pattern = rf"^def\s+{re.escape(function_name)}\s*\("
    text = _select_code_region(text, def_pattern)
    found = _top_level_function_via_ast(text, function_name)
    if found is not None:
        return found
    match = re.search(def_pattern, text, flags=re.M)
    if not match:
        return ""
    return _longest_parseable_prefix(_scan_function_lines(text[match.start():]))
def syntax_ok(source):
    """Report whether *source* parses as Python.

    Returns ``(True, "")`` on success, otherwise ``(False, message)``.
    Before Python 3.12, NUL bytes make ``ast.parse`` raise ValueError rather
    than SyntaxError. Both count as a syntax failure instead of propagating.
    """
    try:
        ast.parse(source)
    except (SyntaxError, ValueError) as e:
        return False, str(e)
    return True, ""
# Identifiers that generated code may not reference at all: static_safety_check
# rejects any Name node (including call targets) whose id is in this set.
FORBIDDEN_NAMES = set(
    """
    eval exec compile open input __import__
    globals locals vars dir getattr setattr delattr
    help breakpoint exit quit
    """.split()
)
# AST node classes that static_safety_check rejects wherever they appear:
# imports, scope escapes, context managers, async/class definitions,
# deletion and raising.
FORBIDDEN_NODES = tuple(
    getattr(ast, node_name)
    for node_name in (
        "Import", "ImportFrom",
        "Global", "Nonlocal",
        "With", "AsyncWith",
        "AsyncFunctionDef", "ClassDef",
        "Delete", "Raise",
    )
)
# The complete builtins namespace given to sandboxed code (_worker_run_tests
# installs it as ``__builtins__``). Each key is the callable's own __name__,
# in alphabetical order.
ALLOWED_BUILTINS = {
    builtin.__name__: builtin
    for builtin in (
        abs, all, any, bool, dict, enumerate, float, int, isinstance, len,
        list, map, max, min, pow, range, reversed, round, set, sorted, str,
        sum, tuple, zip,
    )
}
# Attribute names that expose interpreter internals without a leading "__":
# frame, generator/coroutine/async-generator, and traceback objects.
# Example: inside ``g = (g.gi_frame.f_back for _ in [0])``, iterating ``g``
# yields the frame of whoever iterates it. Following further ``f_back`` links
# reaches the test worker's own frame, whose ``f_builtins`` is the real,
# unrestricted builtins dict. ``mro`` is included because ``int.mro()``
# hands out ``object``.
_FORBIDDEN_ATTRS = frozenset({
    "gi_frame", "gi_code", "gi_yieldfrom",
    "cr_frame", "cr_code", "cr_await", "cr_origin",
    "ag_frame", "ag_code", "ag_await",
    "f_back", "f_builtins", "f_code", "f_globals", "f_locals", "f_trace",
    "tb_frame", "tb_next",
    "mro",
})


def _is_forbidden_attr(name):
    """Return True for dunder attribute names and the introspection names above."""
    return name.startswith("__") or name in _FORBIDDEN_ATTRS


def static_safety_check(source):
    """Statically reject *source* if it uses constructs the sandbox disallows.

    Returns ``(True, "passed")`` or ``(False, reason)``. Rejected:

    - code that does not parse;
    - any node type in FORBIDDEN_NODES;
    - names in FORBIDDEN_NAMES or starting with "__";
    - dunder or interpreter-introspection attribute access, including
      attribute names in ``match`` class patterns.
    """
    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError) as e:  # ValueError: NUL bytes before 3.12
        return False, f"SyntaxError: {e}"
    for node in ast.walk(tree):
        if isinstance(node, FORBIDDEN_NODES):
            return False, f"Forbidden AST node: {type(node).__name__}"
        if isinstance(node, ast.Name):
            if node.id in FORBIDDEN_NAMES or node.id.startswith("__"):
                return False, f"Forbidden name: {node.id}"
        if isinstance(node, ast.Attribute):
            if _is_forbidden_attr(node.attr):
                return False, f"Forbidden attribute: {node.attr}"
        # ast.MatchClass (3.10+) stores keyword-pattern attribute names as plain
        # strings, e.g. ``case int(__class__=c)`` reads ``__class__`` without
        # any ast.Attribute node. Other node types have no kwd_attrs.
        for attr in getattr(node, "kwd_attrs", ()):
            if _is_forbidden_attr(attr):
                return False, f"Forbidden attribute: {attr}"
        if isinstance(node, ast.Call):
            # ast.walk is breadth-first, so a forbidden call is reported here
            # before its ast.Name child is visited.
            if isinstance(node.func, ast.Name) and node.func.id in FORBIDDEN_NAMES:
                return False, f"Forbidden call: {node.func.id}"
    return True, "passed"
def _portable(value):
    """Return *value* if it pickles, else ``repr(value)``.

    multiprocessing queues pickle items in a background feeder thread. An
    item that fails to pickle there is reported to stderr and dropped, so the
    parent would receive no result at all.
    """
    import pickle

    try:
        pickle.dumps(value)
    except Exception:
        return repr(value)
    return value


def _worker_run_tests(source, function_name, tests, queue):
    """Child-process entry point: run *tests* against ``function_name`` from *source*.

    *source* is executed with ALLOWED_BUILTINS as its only builtins. Each test
    is a dict with ``"expected"`` and optional ``"args"`` / ``"kwargs"``.

    Exactly one summary dict is put on *queue*, with keys "ok", "error",
    "passed", "total" and "details". A test whose call raises counts as
    failed; its repr is stored in that detail's "error" and the remaining
    tests still run.
    """
    try:
        # One dict serves as both globals and locals. With two separate dicts,
        # exec behaves like a class body: top-level defs land in the locals
        # dict, and a recursive call (a global lookup) fails with NameError.
        namespace = {"__builtins__": ALLOWED_BUILTINS}
        exec(compile(source, "", "exec"), namespace)
        fn = namespace.get(function_name)
        if fn is None:
            queue.put({"ok": False, "error": f"{function_name} not found", "passed": 0, "total": len(tests), "details": []})
            return
        passed = 0
        first_error = ""
        details = []
        for test in tests:
            args = test.get("args", [])
            kwargs = test.get("kwargs", {})
            expected = test["expected"]  # a malformed test still fails the whole run
            try:
                result = fn(*args, **kwargs)
                ok = bool(result == expected)
                error = ""
            except Exception as e:
                result, ok, error = None, False, repr(e)
                first_error = first_error or error
            passed += int(ok)
            details.append({
                "args": args,
                "kwargs": kwargs,
                "expected": expected,
                "result": _portable(result),
                "ok": ok,
                "error": error,
            })
        queue.put({"ok": passed == len(tests), "error": first_error, "passed": passed, "total": len(tests), "details": details})
    except Exception as e:
        queue.put({"ok": False, "error": repr(e), "passed": 0, "total": len(tests), "details": []})
def run_unit_tests_safely(source, function_name, tests, timeout_seconds=3):
    """Statically vet *source*, then run *tests* in a forked worker under a timeout.

    Returns the worker's summary dict. On failure it returns a dict with
    ok=False, passed=0 and empty details, whose "error" is one of:

    - the static-check reason;
    - "timeout" if no result arrives within *timeout_seconds*;
    - "no result returned" if the worker exits without reporting.
    """
    import queue as queue_module
    import time

    def failure(error):
        return {"ok": False, "error": error, "passed": 0, "total": len(tests), "details": []}

    safe, reason = static_safety_check(source)
    if not safe:
        return failure(reason)

    ctx = mp.get_context("fork")
    results = ctx.Queue()
    process = ctx.Process(target=_worker_run_tests, args=(source, function_name, tests, results))
    process.start()
    outcome = None
    error = "timeout"
    deadline = time.monotonic() + timeout_seconds
    try:
        # Read BEFORE joining. A worker that queued more data than the pipe
        # buffer holds cannot exit until that data is read. join()-then-get()
        # would therefore report a large result as a timeout.
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                outcome = results.get(timeout=min(remaining, 0.05))
                break
            except queue_module.Empty:
                if not process.is_alive():
                    # The worker has exited, so anything it queued is already
                    # in the pipe; one non-blocking read settles the race.
                    try:
                        outcome = results.get_nowait()
                    except queue_module.Empty:
                        error = "no result returned"
                    break
    finally:
        # A worker that has reported only needs to exit; give it a moment
        # before terminating it. Otherwise terminate at once.
        process.join(0.5 if outcome is not None else 0)
        if process.is_alive():
            process.terminate()
            process.join()
        results.close()
    return outcome if outcome is not None else failure(error)
def code_complexity(source):
    """Return the highest ``.complexity`` among the blocks ``cc_visit`` finds in *source*.

    ``cc_visit`` is imported elsewhere in this module (presumably radon's
    cyclomatic-complexity visitor; confirm against the import).

    Returns 1 when no blocks are found, and None if the analysis raises any
    Exception.
    """
    try:
        found = cc_visit(source)
        return max(item.complexity for item in found) if found else 1
    except Exception:
        return None
def score_candidate(source, test_result):
    """Combine static checks, test outcome and complexity into one score.

    score = parses (0/1) + passes static safety check (0/1)
            + 3 * passed / total - min(complexity / 20, 0.25)

    ``total`` is clamped to at least 1. The complexity penalty is 0 when
    code_complexity returns None.
    """
    parses, _ = syntax_ok(source)
    is_safe, _ = static_safety_check(source)
    pass_fraction = test_result.get("passed", 0) / max(test_result.get("total", 1), 1)
    complexity = code_complexity(source)
    if complexity is None:
        penalty = 0
    else:
        penalty = min(complexity / 20, 0.25)
    return int(parses) + int(is_safe) + 3 * pass_fraction - penalty
# Credit: Source link

























