def compile(self, student: Program, trainset: List[Example], teacher: Optional[Program] = None) -> Program:
    """Bootstrap trace data from teacher program(s) and fine-tune the student's LMs on it.

    Args:
        student: Program whose predictors' LMs are fine-tuned. It is passed
            through ``prepare_student`` first; the resulting object is the one
            that is updated and returned.
        trainset: Examples on which each teacher program is run (via
            ``bootstrap_trace_data``) to collect training traces.
        teacher: A single teacher program, a list of teacher programs, or
            ``None``. Every entry (including ``None``) is passed through
            ``prepare_teacher(student, ...)``. Traces from *all* teachers are
            pooled into one training set.

    Returns:
        The prepared student, with each predictor's ``lm`` replaced by the LM
        returned from ``self.finetune_lms`` and ``_compiled`` set to ``True``.
        If ``self.exclude_demos`` is truthy, each predictor's ``demos`` is
        also cleared.

    Raises:
        AssertionError: If the number of distinct fine-tuning jobs exceeds
            ``self.num_threads``.
    """
    # TODO: Print statements can be converted to logger.info if we ensure
    # that the default DSPy logger logs info level messages in notebook
    # environments.
    print("[BootstrapFinetune] Preparing the student and teacher programs...")
    student = prepare_student(student)
    teachers = teacher if isinstance(teacher, list) else [teacher]
    teachers = [prepare_teacher(student, t) for t in teachers]
    set_missing_predictor_lms(student)

    print("[BootstrapFinetune] Bootstrapping data...")
    trace_data = []
    for t in teachers:
        set_missing_predictor_lms(t)
        # Accumulate instead of reassigning so that traces from every teacher,
        # not just the last one in the list, reach the fine-tuning data.
        trace_data.extend(
            bootstrap_trace_data(program=t, dataset=trainset, metric=self.metric, num_threads=self.num_threads)
        )

    print("[BootstrapFinetune] Preparing the train data...")
    # One fine-tuning job per distinct key. In multitask mode the predictor
    # index is dropped from the key, so predictors sharing an LM share a job;
    # otherwise every predictor gets its own job.
    key_to_data = {}
    for pred_ind, pred in enumerate(student.predictors()):
        data_pred_ind = None if self.multitask else pred_ind
        training_key = (pred.lm, data_pred_ind)
        if training_key not in key_to_data:
            train_data, data_format = self._prepare_finetune_data(trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind)
            print(f"[BootstrapFinetune] Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
            finetune_kwargs = dict(lm=pred.lm, train_data=train_data, train_kwargs=self.train_kwargs[pred.lm], data_format=data_format)
            key_to_data[training_key] = finetune_kwargs

    print("[BootstrapFinetune] Starting LM fine-tuning...")
    # TODO(feature): We could run batches of fine-tuning jobs in sequence
    # to avoid exceeding the number of threads.
    err = f"BootstrapFinetune requires `num_threads` to be bigger than or equal to the number of fine-tuning jobs. There are {len(key_to_data)} fine-tuning jobs to start, but the number of threads is: {self.num_threads}! If the `multitask` flag is set to False, the number of fine-tuning jobs will be equal to the number of predictors in the student program. If the `multitask` flag is set to True, the number of fine-tuning jobs will be equal to: 1 if there is only a context LM, or the number of unique LMs attached to the predictors in the student program. In any case, the number of fine-tuning jobs will be less than or equal to the number of predictors."
    assert len(key_to_data) <= self.num_threads, err
    print(f"[BootstrapFinetune] {len(key_to_data)} fine-tuning job(s) to start")
    key_to_lm = self.finetune_lms(key_to_data)

    print("[BootstrapFinetune] Updating the student program with the fine-tuned LMs...")
    # Recompute each predictor's key exactly as above to look up its fine-tuned LM.
    for pred_ind, pred in enumerate(student.predictors()):
        data_pred_ind = None if self.multitask else pred_ind
        training_key = (pred.lm, data_pred_ind)
        pred.lm = key_to_lm[training_key]
        # TODO: What should the correct behavior be here? Should
        # BootstrapFinetune modify the prompt demos according to the
        # train data?
        pred.demos = [] if self.exclude_demos else pred.demos

    print("[BootstrapFinetune] BootstrapFinetune has finished compiling the student program")
    student._compiled = True
    return student