def rollout(self, *, task, context, reset_env=True):
    """Run one task from reset until a step reports completion.

    Calls ``self.reset`` once with the given arguments, then calls
    ``self.step`` repeatedly until the ``done`` element of its result is
    truthy.

    Args:
        task: Forwarded to ``self.reset`` as ``task``.
        context: Forwarded to ``self.reset`` as ``context``.
        reset_env: Forwarded to ``self.reset`` as ``reset_env``.

    Returns:
        A ``(messages, reward, done, info)`` tuple holding the values
        produced by the final ``self.step`` call.
    """
    self.reset(task=task, context=context, reset_env=reset_env)
    # step() always runs at least once; keep stepping until it reports done.
    messages, reward, done, info = self.step()
    while not done:
        messages, reward, done, info = self.step()
    return messages, reward, done, info
def reset(self, task, context="", reset_env=True):
# Build the opening two-message conversation for a task and return it.
#
# Parameters, as far as the visible lines show:
#   task      -- not read directly below; the human message is rendered from
#                self.task (presumably assigned from `task` in an omitted
#                line -- TODO confirm).
#   context   -- passed to render_human_message as `context`.
#   reset_env -- not referenced in the visible lines (presumably handled in
#                an omitted section -- TODO confirm).
#
# Returns self.messages, which is set to [system_message, human_message].
...
# Retrieve skills using self.context as the query.
# NOTE(review): this reads self.context, while the human message below is
# given the `context` argument; presumably an omitted line above sets
# self.context = context -- confirm.
skills = self.skill_manager.retrieve_skills(query=self.context)
...
# The system message is rendered from the retrieved skills.
system_message = self.action_agent.render_system_message(skills=skills)
# `events` is not assigned in the visible lines; presumably it is produced
# by an omitted section above -- confirm.
# `code` and `critique` are passed as empty strings here; step() passes the
# parsed program code and the critic's critique in the same positions.
human_message = self.action_agent.render_human_message(
events=events, code="", task=self.task, context=context, critique=""
)
# The conversation is always exactly [system, human] after this assignment.
self.messages = [system_message, human_message]
...
return self.messages
def step(self):
# Run one iteration of the action loop.
#
# Sends self.messages to the action agent's LLM and parses the reply. When
# the parsed result is a dict, the code it carries is executed through
# self.env.step, the critic agent judges the outcome, and self.messages is
# rebuilt from fresh skills, the new events and the critique.
#
# Returns (self.messages, 0, done, info). The second element is always the
# literal 0.
# NOTE(review): `info` is not assigned in the visible lines; presumably it
# is built in an omitted section before the return -- confirm.
...
ai_message = self.action_agent.llm(self.messages)
...
# Record the contents of (system message, human message, AI reply) for this
# iteration.
self.conversations.append(
(self.messages[0].content, self.messages[1].content, ai_message.content)
)
parsed_result = self.action_agent.process_ai_message(message=ai_message)
# Start from failure; in the visible lines `success` is only reassigned by
# the critic call below, which runs only when parsed_result is a dict.
success = False
# A dict result is indexed by "program_code" and "exec_code"; any other
# result is handled by the `else` further down (its body is omitted here).
if isinstance(parsed_result, dict):
# The code sent to the environment is program_code, a newline, then
# exec_code.
code = parsed_result["program_code"] + "\n" + parsed_result["exec_code"]
events = self.env.step(
code,
programs=self.skill_manager.programs,
)
...
# The critic returns a (success, critique) pair; its arguments are omitted
# in this view.
success, critique = self.critic_agent.check_task_success(
...
)
# Extra handling when reset_placed_if_failed is set and the attempt was not
# successful; the body is omitted in this view.
if self.reset_placed_if_failed and not success:
...
# Retrieve skills again; the query is self.context, a blank line, and the
# action agent's summary of the chat log for these events.
new_skills = self.skill_manager.retrieve_skills(
query=self.context
+ "\n\n"
+ self.action_agent.summarize_chatlog(events)
)
system_message = self.action_agent.render_system_message(skills=new_skills)
# Only program_code (without exec_code) is passed as `code` for the next
# human message, together with the critic's critique.
human_message = self.action_agent.render_human_message(
events=events,
code=parsed_result["program_code"],
task=self.task,
context=self.context,
critique=critique,
)
# Store a deep copy so self.last_events does not share objects with
# `events`.
self.last_events = copy.deepcopy(events)
# Replace the conversation with the fresh [system, human] pair.
self.messages = [system_message, human_message]
else:
...
...
# done is true when self.action_agent_rollout_num_iter has reached
# self.action_agent_task_max_retries, or when `success` is truthy.
# NOTE(review): the counter is not updated in the visible lines; presumably
# it is incremented in an omitted section -- confirm.
done = (
self.action_agent_rollout_num_iter >= self.action_agent_task_max_retries
or success
)
...
return self.messages, 0, done, info