-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathbfs.py
More file actions
228 lines (211 loc) · 9.9 KB
/
bfs.py
File metadata and controls
228 lines (211 loc) · 9.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# LICENSE HEADER MANAGED BY add-license-header
#
# /*
# * Copyright (c) 2023, Salesforce, Inc.
# * SPDX-License-Identifier: Apache-2
# *
# * Licensed under the Apache License, Version 2.0 (the "License");
# * you may not use this file except in compliance with the License.
# * You may obtain a copy of the License at
# *
# * http://www.apache.org/licenses/LICENSE-2.0
# *
# * Unless required by applicable law or agreed to in writing, software
# * distributed under the License is distributed on an "AS IS" BASIS,
# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# * See the License for the specific language governing permissions and
# * limitations under the License.
# */
#
import openai
from utils import enumerate_resume, make_printv, write_jsonl, resume_success_count
from executors import executor_factory
from generators import generator_factory, model_factory
from typing import List, Dict, Tuple, Any
import math
import re
import sys
from collections import Counter
from common import gen_test_eval
# bfs real
# Python 3.11+ (and some patched older releases) cap int<->str conversions at
# 4300 digits by default; raise the cap to 100000 digits, presumably so very
# large integers produced or compared by generated programs can still be
# converted. Interpreters without the cap also lack the setter, so the call is
# skipped there instead of raising AttributeError at import time.
if hasattr(sys, "set_int_max_str_digits"):
    sys.set_int_max_str_digits(100000)
# NOTE: many candidates pass the visible tests but fail the hidden ones; one
# idea is to sample several passing solutions and pick one with an agreement
# function.
# Text inserted between the problem prompt and each strategy hint when asking
# the model for a hint-guided solution.
ADD_HINT = "To solve the problem, you can refer the hint given by an expert, and complete the details by analyzing it first.\nHint:"
# TODO: From sample to list
class Node:
    """A candidate solution in the search tree built by ``run_bfs``.

    Each node holds one generated program plus the bookkeeping used to rank
    and expand it: accumulated reward (``value``), visit count, the feedback
    from running it on the visible tests, and the strategy it was derived from.
    """

    def __init__(self, solution: str, parent=None, context="", depth=0):
        """Create a node.

        Args:
            solution: the generated program text.
            parent: the node this one was derived from (``None`` for a root).
            context: free-form context stored on the node.
            depth: distance from the root; ``run_bfs`` stops expanding nodes
                at its ``max_depth``.
        """
        self.solution = solution
        self.parent = parent
        self.children = []
        self.value = 0   # accumulated reward (sum of update() calls)
        self.visits = 0
        self.context = context  # previously discarded; now kept as given
        self.depth = depth
        self.reflection = ""
        self.test_feedback = ""
        self.strategy = ""

    def uct(self, exploration_weight=1.0):
        """Return the UCT score used to pick a child to explore.

        Unvisited nodes score their raw ``value`` (rather than +inf). Visited
        nodes score ``value / visits`` plus an exploration bonus based on the
        parent's visit count. When there is no parent, or the parent has no
        visits, the bonus is undefined and only the exploitation term is
        returned.
        """
        if self.visits == 0:
            return self.value
        exploitation = self.value / self.visits
        if self.parent is None or self.parent.visits <= 0:
            # Root node, or a node created without ``parent=``: math.log would
            # fail (no parent / log(0)), so fall back to pure exploitation.
            return exploitation
        exploration = math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + exploration_weight * exploration

    def best_child(self):
        """Return the child with the highest ``uct()`` score, or ``None`` if there are no children."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.uct())

    def best_child_value(self):
        """Return the child with the highest raw ``value``, or ``None`` if there are no children."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.value)

    def sort_children_by_value(self):
        """Sort ``children`` in place by ``value``, highest first."""
        self.children.sort(key=lambda x: x.value, reverse=True)

    def update(self, reward: float):
        """Record one visit and add ``reward`` to the accumulated value."""
        self.visits += 1
        self.value += reward
def rerank_list_of_nodes(list_of_nodes):
    """Return a new list of the nodes ordered by ``value``, largest first.

    The input sequence is left untouched. The sort is stable, so nodes with
    equal ``value`` keep their original relative order.
    """
    ranked = list(list_of_nodes)
    ranked.sort(key=lambda node: node.value, reverse=True)  # largest value in front
    return ranked
def _evaluate_hidden_tests(exe, item: dict, solution: str) -> Tuple[Any, Any]:
    """Run a solution that passed the visible tests against the hidden suites.

    Returns ``(is_solved, is_weaker_solved)``: the ``exe.evaluate`` result on
    ``item["test"]`` and, if the item has a ``"weaker_test"`` entry, on that
    suite (``False`` when it has none).
    """
    is_solved = exe.evaluate(
        item["entry_point"], solution, item["test"], timeout=1, prev=item["prev"])
    is_weaker_solved = False
    if "weaker_test" in item:
        is_weaker_solved = exe.evaluate(
            item["entry_point"], solution, item["weaker_test"], timeout=1, prev=item["prev"])
    return is_solved, is_weaker_solved


def run_bfs(
        dataset: List[dict],
        model_name: str,
        language: str,
        log_path: str,
        verbose: bool,
        max_iters: int,
        is_leetcode: bool = False,
        max_depth: int = 3,
        search_width: int = 3,
        Codecontests: bool = False
) -> None:
    """Solve each problem with hint-guided sampling, then breadth-first debugging.

    For every item in ``dataset``:

    1. Ask the model for up to ``search_width`` strategies (hints) and generate
       one solution per hint, checking each against the visible tests. The
       first solution that passes is run on the hidden tests and the item is
       finished.
    2. Otherwise, the failing candidates are ranked by reward and expanded in
       queue (FIFO) order. Each expanded node gets reflections from the model
       and one revised solution per reflection. The node's children are
       appended best-first. Expansion stops when a solution passes the visible
       tests, the queue is empty, or ``max_iters`` generations have been used.
       Nodes at depth ``max_depth`` or deeper are not expanded.

    After each item, running ``acc``/``weak_acc`` are written onto the item and
    the item is appended to ``log_path`` as JSONL. Summary statistics are
    printed at the end.

    Args:
        dataset: problems; each is read for ``prompt``, ``given_tests``,
            ``test``, ``prev`` and, when present, ``entry_point``,
            ``weaker_test`` and ``difficulty``. Items are mutated in place
            (``acc``, ``weak_acc``, and ``entry_point`` for Codecontests).
        model_name: name passed to ``model_factory``.
        language: passed to ``generator_factory`` / ``executor_factory``.
        log_path: JSONL file that per-item results are appended to.
        verbose: enables the per-item progress lines.
        max_iters: generation budget per problem; the hint samples count
            toward it.
        is_leetcode: forwarded to ``executor_factory``.
        max_depth: nodes at this depth or deeper are not expanded.
        search_width: number of hints requested, and of reflections per node.
        Codecontests: use the ``code_contests`` executor and skip the
            ``entry_point`` filtering of visible tests.
    """
    if Codecontests:
        exe = executor_factory("code_contests")
    else:
        exe = executor_factory(language, is_leet=is_leetcode)
    gen = generator_factory(language)
    model = model_factory(model_name)
    print_v = make_printv(verbose)

    count, sad_case, debug_thoroughed_case, enter_debug_case = 0, 0, 0, 0
    num_items = len(dataset)
    num_success, weak_success = 0, 0  # items solved on the hidden / weaker tests
    passed_at_sample = []  # tries used by each solved item
    debug_case = 0  # items that reached the debugging stage
    pass_problem_subset = []  # difficulty of each solved item, when provided

    for idx, item in enumerate(dataset):
        tests_i = item["given_tests"]
        if Codecontests:
            item["entry_point"] = ""
        else:
            # Keep only visible tests that call the target function and are not placeholders.
            tests_i = [test for test in tests_i
                       if item['entry_point'] in test and 'assert False' not in test]

        hints = gen.strategy(item["prompt"], model, num_strategy=search_width,
                             task="strategy", temperature=0)
        hints = hints[:search_width]
        if len(hints) < search_width:
            count += 1  # model returned fewer strategies than requested

        stack = []
        is_solved, is_weaker_solved = False, False
        # Must be bound before use: if no hints come back the loop below never runs.
        is_passing = False
        num_try = 0

        # Stage 1: one solution per strategy hint.
        for hint in hints:
            cur_func_impl = gen.func_impl(item["prompt"] + f"{ADD_HINT} {hint}\n", model, "simple",
                                          temperature=0)
            num_try += 1
            is_passing, feedback, reward = gen_test_eval(exe, cur_func_impl, tests_i, prev=item["prev"])
            if is_passing:
                # Early exit: only hidden-test results remain to be recorded.
                is_solved, is_weaker_solved = _evaluate_hidden_tests(exe, item, cur_func_impl)
                break
            new_node = Node(cur_func_impl)
            new_node.test_feedback = feedback
            new_node.update(reward)
            new_node.strategy = hint
            stack.append(new_node)

        # Exit when passed public test cases.
        if is_passing:
            if is_solved:
                num_success += int(is_solved)
                passed_at_sample.append(num_try)
                if "difficulty" in item:
                    pass_problem_subset.append(item["difficulty"])
            else:
                print("SAD, passed but not solved.")
                sad_case += 1
            if is_weaker_solved:
                weak_success += int(is_weaker_solved)
            item["weak_acc"] = round(weak_success / (idx + 1), 3)
            item["acc"] = round(num_success / (idx + 1), 3)
            write_jsonl(log_path, [item], append=True)
            print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={item["weak_acc"]}, pass no solve: {sad_case}, enter debug: {debug_case}')
            continue  # early stop on this case if passed

        # Stage 2: debug the failing candidates, best-valued first.
        print("Entering Debugging Stage")
        debug_case += 1
        stack = rerank_list_of_nodes(stack)
        print("Stack after sorting: ", [a.value for a in stack])
        while stack and num_try < max_iters and not is_passing:
            this_node = stack.pop(0)  # FIFO: expand in queue order (breadth-first)
            if this_node.depth >= max_depth:
                continue
            this_node.visits += 1
            reflections = gen.strategy(item["prompt"],
                                       model, task="reflection",
                                       num_strategy=search_width,
                                       prev_func_impl=this_node.solution,
                                       feedback=this_node.test_feedback,
                                       temperature=0,
                                       given_strategy=this_node.strategy)
            if len(reflections) < 2:
                print("not enough reflections!")
            for reflection in reflections:
                if num_try >= max_iters:
                    break
                new_solution, while_cnt = None, 0
                # Retry a few times in case the generator returns nothing.
                while new_solution is None and while_cnt < 3:
                    while_cnt += 1
                    new_solution = gen.func_impl(
                        func_sig=item["prompt"],
                        model=model,
                        strategy="reflexion",
                        prev_func_impl=this_node.solution,
                        feedback=this_node.test_feedback,
                        self_reflection=reflection,
                        temperature=0
                    )
                num_try += 1  # the generation attempt consumes budget either way
                if new_solution is None:
                    # Do not hand a missing program to the evaluator.
                    print("No solution generated for this reflection; skipping.")
                    continue
                is_passing, feedback, reward = gen_test_eval(exe, new_solution, tests_i, prev=item["prev"])
                if is_passing:
                    is_solved, is_weaker_solved = _evaluate_hidden_tests(exe, item, new_solution)
                    break
                new_node = Node(new_solution, parent=this_node, depth=this_node.depth + 1)
                new_node.test_feedback = feedback
                new_node.update(reward)
                new_node.strategy = this_node.strategy
                this_node.children.append(new_node)
            if is_passing:
                break
            this_node.sort_children_by_value()
            stack.extend(this_node.children)
            print("Children after sorting: ", [a.value for a in stack])

        if num_try >= max_iters and not is_passing:
            debug_thoroughed_case += 1  # budget exhausted without passing the visible tests
        if is_passing:
            enter_debug_case += 1  # passed the visible tests after debugging
            if is_solved:
                num_success += int(is_solved)
                passed_at_sample.append(num_try)
                if "difficulty" in item:
                    pass_problem_subset.append(item["difficulty"])
            else:
                sad_case += 1
                print("Sad, pass but not solve")
        if is_weaker_solved:
            weak_success += int(is_weaker_solved)
        item["weak_acc"] = round(weak_success / (idx + 1), 3)
        item["acc"] = round(num_success / (idx + 1), 3)
        write_jsonl(log_path, [item], append=True)
        print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={item["weak_acc"]}, pass no solve: {sad_case}, enter debug: {debug_case}')

    print("_______________________________")
    print(passed_at_sample)
    print(sorted(passed_at_sample))
    print(len(passed_at_sample))
    print(Counter(passed_at_sample))
    print("Passed but not solved case", sad_case)
    print("not sample 2 even when asked: ", count)
    print("20 tries used still not solve:", debug_thoroughed_case)
    print("Pass not solve after debugging", enter_debug_case)
    print(Counter(pass_problem_subset))
    # After the loop every item has been processed; `idx` would be unbound for
    # an empty dataset, and max(..., 1) avoids dividing by zero there.
    processed = max(num_items, 1)
    print_v(f'completed {num_items}/{num_items}: acc = {round(num_success / processed, 3)}, weak_acc={round(weak_success / processed, 3)}')