"""Metric computation utilities."""
import os
from typing import Dict, List, Optional, Set, Tuple
from ..core import ByteInterval, intersect_size, length
def coverage_precision(pred_size: float, gold_size: float, inter_size: float) -> Tuple[float, float]:
"""Compute (coverage, precision). Edge cases: gold=0→cov=1.0, pred=0→prec=1.0."""
cov = inter_size / gold_size if gold_size > 0 else 1.0
prec = inter_size / pred_size if pred_size > 0 else 1.0
return cov, prec
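# Illustrative example (not from the original module): 6 of 8 gold bytes are covered
# and 6 of 10 predicted bytes are relevant.
#   >>> coverage_precision(pred_size=10.0, gold_size=8.0, inter_size=6.0)
#   (0.75, 0.6)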
def span_total_bytes(spans_by_file: Dict[str, List[ByteInterval]]) -> int:
"""Total bytes across all files."""
return sum(length(intervals) for intervals in spans_by_file.values())
def span_intersection_bytes(a: Dict[str, List[ByteInterval]], b: Dict[str, List[ByteInterval]]) -> int:
"""Total intersection bytes across all files."""
total = 0
for f in set(a.keys()) | set(b.keys()):
total += intersect_size(a.get(f, []), b.get(f, []))
return total
# Line-level metrics (replacing span metrics)
LineInterval = Tuple[int, int] # (start_line, end_line) inclusive
def _merge_line_intervals(intervals: List[LineInterval]) -> List[LineInterval]:
"""Merge overlapping or adjacent line intervals."""
if not intervals:
return []
sorted_intervals = sorted(intervals)
merged = [sorted_intervals[0]]
for current in sorted_intervals[1:]:
last = merged[-1]
if current[0] <= last[1] + 1:
merged[-1] = (last[0], max(last[1], current[1]))
else:
merged.append(current)
return merged
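# Illustrative example (hypothetical intervals): because end bounds are inclusive,
# touching intervals such as (1, 3) and (4, 4) are merged into one contiguous block.
#   >>> _merge_line_intervals([(5, 8), (1, 3), (4, 4)])
#   [(1, 8)]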
def _line_interval_length(intervals: List[LineInterval]) -> int:
"""Total lines covered by intervals."""
return sum(end - start + 1 for start, end in _merge_line_intervals(intervals))
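# Illustrative example: (1, 3) and (2, 6) merge to (1, 6), i.e. 6 distinct lines.
#   >>> _line_interval_length([(1, 3), (2, 6)])
#   6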
def _line_interval_intersect(a: List[LineInterval], b: List[LineInterval]) -> List[LineInterval]:
"""Intersection of two line interval lists."""
a_m, b_m = _merge_line_intervals(a), _merge_line_intervals(b)
result = []
i, j = 0, 0
while i < len(a_m) and j < len(b_m):
overlap = (max(a_m[i][0], b_m[j][0]), min(a_m[i][1], b_m[j][1]))
if overlap[0] <= overlap[1]:
result.append(overlap)
if a_m[i][1] < b_m[j][1]:
i += 1
elif b_m[j][1] < a_m[i][1]:
j += 1
else:
i += 1
j += 1
return result
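# Illustrative example (hypothetical intervals): the two-pointer sweep above yields the
# overlapping pieces in sorted order; _line_interval_intersect_size below just sums them.
#   >>> _line_interval_intersect([(1, 5), (10, 12)], [(3, 11)])
#   [(3, 5), (10, 11)]
#   >>> _line_interval_intersect_size([(1, 5), (10, 12)], [(3, 11)])
#   5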
def _line_interval_intersect_size(a: List[LineInterval], b: List[LineInterval]) -> int:
"""Lines in intersection."""
return _line_interval_length(_line_interval_intersect(a, b))
def line_total_lines(lines_by_file: Dict[str, List[LineInterval]]) -> int:
"""Total lines across all files."""
return sum(_line_interval_length(intervals) for intervals in lines_by_file.values())
def line_intersection_lines(a: Dict[str, List[LineInterval]], b: Dict[str, List[LineInterval]]) -> int:
"""Total intersection lines across all files."""
total = 0
for f in set(a.keys()) | set(b.keys()):
total += _line_interval_intersect_size(a.get(f, []), b.get(f, []))
return total
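# Illustrative example (hypothetical paths): totals are summed per file, and files that
# appear on only one side contribute nothing to the intersection.
#   >>> pred = {"a.py": [(1, 10)], "b.py": [(5, 6)]}
#   >>> gold = {"a.py": [(8, 12)]}
#   >>> line_total_lines(pred), line_intersection_lines(pred, gold)
#   (12, 3)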
def compute_granularity_metrics(
pred_files: Set[str],
pred_defs: Set[Tuple[str, str, int, int]],
pred_spans: Dict[str, List[ByteInterval]],
gold_files: Set[str],
gold_defs: Set[Tuple[str, str, int, int]],
gold_spans: Dict[str, List[ByteInterval]],
    pred_lines: Optional[Dict[str, List[LineInterval]]] = None,
    gold_lines: Optional[Dict[str, List[LineInterval]]] = None
) -> dict:
"""Compute metrics at all granularities."""
# File
file_inter = len(pred_files & gold_files)
file_cov, file_prec = coverage_precision(len(pred_files), len(gold_files), file_inter)
# Def
def_inter = len(pred_defs & gold_defs)
def_cov, def_prec = coverage_precision(len(pred_defs), len(gold_defs), def_inter)
# Span (byte-level, kept for compatibility)
span_pred = span_total_bytes(pred_spans)
span_gold = span_total_bytes(gold_spans)
span_inter = span_intersection_bytes(pred_spans, gold_spans)
span_cov, span_prec = coverage_precision(span_pred, span_gold, span_inter)
# Line (line-level, new metric)
if pred_lines is None or gold_lines is None:
        # If not provided, fall back to zeroed line metrics
line_pred = 0
line_gold = 0
line_inter = 0
line_cov, line_prec = 0.0, 0.0
else:
line_pred = line_total_lines(pred_lines)
line_gold = line_total_lines(gold_lines)
line_inter = line_intersection_lines(pred_lines, gold_lines)
line_cov, line_prec = coverage_precision(line_pred, line_gold, line_inter)
return {
"file": {"coverage": file_cov, "precision": file_prec,
"intersection": file_inter, "gold_size": len(gold_files), "pred_size": len(pred_files)},
"symbol": {"coverage": def_cov, "precision": def_prec,
"intersection": def_inter, "gold_size": len(gold_defs), "pred_size": len(pred_defs)},
"span": {"coverage": span_cov, "precision": span_prec,
"intersection": span_inter, "gold_size": span_gold, "pred_size": span_pred},
"line": {"coverage": line_cov, "precision": line_prec,
"intersection": line_inter, "gold_size": line_gold, "pred_size": line_pred}
}
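# Usage sketch with assumed inputs (the def tuples and file paths below are hypothetical,
# and the byte-span dicts are left empty so only the file/symbol/line granularities are
# exercised; empty spans hit the gold=0/pred=0 edge cases, i.e. coverage = precision = 1.0):
#   >>> m = compute_granularity_metrics(
#   ...     pred_files={"a.py"}, pred_defs={("a.py", "foo", 1, 10)}, pred_spans={},
#   ...     gold_files={"a.py", "b.py"}, gold_defs={("a.py", "foo", 1, 10)}, gold_spans={},
#   ...     pred_lines={"a.py": [(1, 10)]}, gold_lines={"a.py": [(5, 20)]})
#   >>> m["file"]["coverage"], m["file"]["precision"]
#   (0.5, 1.0)
#   >>> m["line"]["coverage"], m["line"]["precision"]
#   (0.375, 0.6)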
def _step_to_line_intervals(step, repo_dir: str) -> Dict[str, List[LineInterval]]:
"""Convert step spans to line intervals."""
result = {}
for span in step.spans:
f = span.get('file')
if not f:
continue
start_line = span.get('start_line', 1)
end_line = span.get('end_line', 1)
if start_line > 0 and end_line > 0:
result.setdefault(f, []).append((start_line, end_line))
# Merge overlapping intervals per file
for f in result:
result[f] = _merge_line_intervals(result[f])
return result
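# Illustrative example -- any object with a .spans list of dicts works; repo_dir is ignored here:
#   >>> from types import SimpleNamespace
#   >>> step = SimpleNamespace(spans=[{"file": "a.py", "start_line": 3, "end_line": 7},
#   ...                               {"file": "a.py", "start_line": 8, "end_line": 9}])
#   >>> _step_to_line_intervals(step, repo_dir=".")
#   {'a.py': [(3, 9)]}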
def compute_trajectory_metrics(
steps, # List[Step]
gold_files: Set[str],
gold_symbols: Set[Tuple[str, str, int, int]],
gold_spans: Dict[str, List[ByteInterval]],
repo_dir: str,
    gold_lines: Optional[Dict[str, List[LineInterval]]] = None
) -> dict:
"""Compute AUC-Coverage, Redundancy, and per-step metrics."""
from ..extractors import extract_def_set_in_spans, extract_def_set_from_symbol_names
from ..core.intervals import merge
T = len(steps)
if T == 0:
return {
"steps": [],
"auc_coverage": {"file": 0.0, "symbol": 0.0, "span": 0.0, "line": 0.0},
"redundancy": {"file": 0.0, "symbol": 0.0, "span": 0.0, "line": 0.0}
}
# gold_lines should be provided by caller (from gold.line_spans_init())
# If not provided, use empty dict (will result in 0 line metrics)
if gold_lines is None:
gold_lines = {}
union_files, union_symbols, union_spans, union_lines = set(), set(), {}, {}
sum_files, sum_symbols, sum_spans, sum_lines = 0, 0, 0, 0
per_step_metrics = []
for t, step in enumerate(steps):
# Convert step to representations
step_files = set(step.files)
step_lines = _step_to_line_intervals(step, repo_dir)
        step_spans = _step_to_byte_spans(step, repo_dir)  # needed for symbol extraction and byte-span coverage
if getattr(step, "symbols", None):
step_symbols = extract_def_set_from_symbol_names(step.symbols, repo_dir)
else:
step_symbols = extract_def_set_in_spans(step_spans, repo_dir)
# Update unions
union_files |= step_files
union_symbols |= step_symbols
for f, ivs in step_spans.items():
union_spans[f] = merge(union_spans.get(f, []) + ivs)
for f, ivs in step_lines.items():
union_lines[f] = _merge_line_intervals(union_lines.get(f, []) + ivs)
# Coverage at this step
file_cov = len(union_files & gold_files) / len(gold_files) if gold_files else 1.0
symbol_cov = len(union_symbols & gold_symbols) / len(gold_symbols) if gold_symbols else 1.0
span_inter = span_intersection_bytes(union_spans, gold_spans)
span_gold = span_total_bytes(gold_spans)
span_cov = span_inter / span_gold if span_gold > 0 else 1.0
# Line coverage
if gold_lines:
line_inter = line_intersection_lines(union_lines, gold_lines)
line_gold = line_total_lines(gold_lines)
line_cov = line_inter / line_gold if line_gold > 0 else 1.0
else:
line_cov = 0.0
per_step_metrics.append({
"step": t + 1,
"coverage": {"file": file_cov, "symbol": symbol_cov, "span": span_cov, "line": line_cov}
})
# Accumulate sizes for redundancy
sum_files += len(step_files)
sum_symbols += len(step_symbols)
sum_spans += span_total_bytes(step_spans)
sum_lines += line_total_lines(step_lines)
# AUC (average coverage across steps)
auc_file = sum(s["coverage"]["file"] for s in per_step_metrics) / T
auc_symbol = sum(s["coverage"]["symbol"] for s in per_step_metrics) / T
auc_span = sum(s["coverage"]["span"] for s in per_step_metrics) / T
auc_line = sum(s["coverage"]["line"] for s in per_step_metrics) / T
# Redundancy
red_file = 1 - len(union_files) / sum_files if sum_files > 0 else 0.0
red_symbol = 1 - len(union_symbols) / sum_symbols if sum_symbols > 0 else 0.0
red_span = 1 - span_total_bytes(union_spans) / sum_spans if sum_spans > 0 else 0.0
red_line = 1 - line_total_lines(union_lines) / sum_lines if sum_lines > 0 else 0.0
return {
"steps": per_step_metrics,
"auc_coverage": {"file": auc_file, "symbol": auc_symbol, "span": auc_span, "line": auc_line},
"redundancy": {"file": red_file, "symbol": red_symbol, "span": red_span, "line": red_line}
}
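# Reference formulas, restating what the loop above computes for each granularity g over T steps
# (Cov_g(t) is the coverage of the cumulative union of steps 1..t against the gold set):
#   AUC_g        = (1/T) * sum_{t=1..T} Cov_g(t)
#   Redundancy_g = 1 - |union of all step contexts| / sum_{t=1..T} |step_t context|
# Redundancy is 0 when every step contributes only new context and approaches 1 when steps
# mostly re-read the same files, symbols, or lines.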
def _step_to_byte_spans(step, repo_dir: str) -> Dict[str, List[ByteInterval]]:
"""Convert step spans to byte intervals."""
from ..core.fileio import line_to_byte
from ..core.intervals import merge
result = {}
for span in step.spans:
f = span.get('file')
if not f:
continue
abs_path = os.path.join(repo_dir, f)
byte_span = line_to_byte(abs_path, span.get('start_line', 1), span.get('end_line', 1))
if byte_span:
result.setdefault(f, []).append(byte_span)
for f in result:
result[f] = merge(result[f])
return result
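# Note: unlike _step_to_line_intervals, this conversion needs the file contents on disk
# (via ..core.fileio.line_to_byte); the `if byte_span:` guard means spans for which
# line_to_byte returns a falsy value are silently skipped.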