Skip to content

dspy.BootstrapFinetune

dspy.BootstrapFinetune(metric: Optional[Callable] = None, multitask: bool = True, train_kwargs: Optional[Union[Dict[str, Any], Dict[LM, Dict[str, Any]]]] = None, adapter: Optional[Union[Adapter, Dict[LM, Adapter]]] = None, exclude_demos: bool = False, num_threads: int = 6)

Bases: FinetuneTeleprompter

Source code in dspy/teleprompt/bootstrap_finetune.py
def __init__(
    self,
    metric: Optional[Callable] = None,
    multitask: bool = True,
    train_kwargs: Optional[Union[Dict[str, Any], Dict[LM, Dict[str, Any]]]] = None,
    adapter: Optional[Union[Adapter, Dict[LM, Adapter]]] = None,
    exclude_demos: bool = False,
    num_threads: int = 6,
):
    """Configure a BootstrapFinetune teleprompter.

    Args:
        metric: Optional callable forwarded as ``metric`` to
            ``bootstrap_trace_data`` when ``compile`` collects teacher traces.
        multitask: If True, ``compile`` groups predictors by LM only, so
            predictors sharing an LM share one fine-tuning job; if False,
            every predictor gets its own job.
        train_kwargs: Training keyword arguments, either a single dict or a
            dict keyed by LM. Passed to the parent constructor; ``compile``
            later reads ``self.train_kwargs[lm]`` (presumably the parent
            normalizes this into a per-LM dict -- confirm in
            FinetuneTeleprompter).
        adapter: An Adapter or a dict mapping LMs to Adapters; normalized to
            a per-LM dict via ``convert_to_lm_dict``.
        exclude_demos: If True, ``compile`` clears each predictor's demos
            after attaching the fine-tuned LM.
        num_threads: Thread count passed to ``bootstrap_trace_data``; also the
            maximum number of fine-tuning jobs ``compile`` will start.
    """
    # TODO(feature): Inputs train_kwargs (a dict with string keys) and
    # adapter (Adapter) can depend on the LM they are used with. We are
    # taking these as parameters for the time being. However, they can be
    # attached to LMs themselves -- an LM could know which adapter it should
    # be used with along with the train_kwargs. This would make the train
    # dataset the only required argument for LM.finetune().

    super().__init__(train_kwargs=train_kwargs)
    self.metric = metric
    self.multitask = multitask
    # Normalized so that adapters can be looked up per LM.
    self.adapter: Dict[LM, Adapter] = self.convert_to_lm_dict(adapter)
    self.exclude_demos = exclude_demos
    self.num_threads = num_threads

Functions

compile(student: Program, trainset: List[Example], teacher: Optional[Program] = None) -> Program

Source code in dspy/teleprompt/bootstrap_finetune.py
def compile(self, student: Program, trainset: List[Example], teacher: Optional[Program] = None) -> Program:
    """Fine-tune the LMs behind ``student``'s predictors on bootstrapped traces.

    Steps: prepare the student and teacher program(s), bootstrap trace data
    by running every teacher over ``trainset``, build one fine-tuning job per
    training key, run the jobs, then attach the fine-tuned LMs to the
    student's predictors.

    Args:
        student: Program whose predictors' LMs are replaced by fine-tuned LMs.
        trainset: Examples the teacher program(s) are run on to collect traces.
        teacher: A teacher program, a list of teacher programs, or None. Each
            entry is passed to ``prepare_teacher`` together with the student.

    Returns:
        The prepared student program with fine-tuned LMs attached and
        ``_compiled`` set to True.

    Raises:
        AssertionError: If more fine-tuning jobs are needed than
            ``self.num_threads`` allows.
    """
    # TODO: Print statements can be converted to logger.info if we ensure
    # that the default DSPy logger logs info level messages in notebook
    # environments.
    print("[BootstrapFinetune] Preparing the student and teacher programs...")
    student = prepare_student(student)
    raw_teachers = teacher if isinstance(teacher, list) else [teacher]
    teachers = [prepare_teacher(student, t) for t in raw_teachers]
    set_missing_predictor_lms(student)

    print("[BootstrapFinetune] Bootstrapping data...")
    trace_data = []
    for t in teachers:
        set_missing_predictor_lms(t)
        # Accumulate across teachers. Plain assignment here would silently
        # discard every teacher's traces except the last one's.
        trace_data.extend(
            bootstrap_trace_data(program=t, dataset=trainset, metric=self.metric, num_threads=self.num_threads)
        )

    print("[BootstrapFinetune] Preparing the train data...")
    key_to_data = {}
    for pred_ind, pred in enumerate(student.predictors()):
        # Multitask mode keys only on the LM (pred_ind=None), so predictors
        # sharing an LM share one job; otherwise each predictor is its own job.
        data_pred_ind = None if self.multitask else pred_ind
        training_key = (pred.lm, data_pred_ind)
        if training_key in key_to_data:
            continue
        train_data, data_format = self._prepare_finetune_data(trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind)
        print(f"[BootstrapFinetune] Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
        key_to_data[training_key] = dict(lm=pred.lm, train_data=train_data, train_kwargs=self.train_kwargs[pred.lm], data_format=data_format)

    print("[BootstrapFinetune] Starting LM fine-tuning...")
    # TODO(feature): We could run batches of fine-tuning jobs in sequence
    # to avoid exceeding the number of threads.
    if len(key_to_data) > self.num_threads:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; AssertionError is kept so existing
        # callers see the same exception type.
        err = f"BootstrapFinetune requires `num_threads` to be bigger than or equal to the number of fine-tuning jobs. There are {len(key_to_data)} fine-tuning jobs to start, but the number of threads is: {self.num_threads}! If the `multitask` flag is set to False, the number of fine-tuning jobs will be equal to the number of predictors in the student program. If the `multitask` flag is set to True, the number of fine-tuning jobs will be equal to: 1 if there is only a context LM, or the number of unique LMs attached to the predictors in the student program. In any case, the number of fine-tuning jobs will be less than or equal to the number of predictors."
        raise AssertionError(err)
    print(f"[BootstrapFinetune] {len(key_to_data)} fine-tuning job(s) to start")
    key_to_lm = self.finetune_lms(key_to_data)

    print("[BootstrapFinetune] Updating the student program with the fine-tuned LMs...")
    for pred_ind, pred in enumerate(student.predictors()):
        # Must recompute the same key used when the jobs were built above.
        data_pred_ind = None if self.multitask else pred_ind
        training_key = (pred.lm, data_pred_ind)
        pred.lm = key_to_lm[training_key]
        # TODO: What should the correct behavior be here? Should
        # BootstrapFinetune modify the prompt demos according to the
        # train data?
        pred.demos = [] if self.exclude_demos else pred.demos

    print("[BootstrapFinetune] BootstrapFinetune has finished compiling the student program")
    student._compiled = True
    return student

finetune_lms(finetune_dict) -> Dict[Any, LM] (staticmethod)

Source code in dspy/teleprompt/bootstrap_finetune.py
@staticmethod
def finetune_lms(finetune_dict) -> Dict[Any, LM]:
    """Start one fine-tuning job per entry, then wait for all of them.

    Args:
        finetune_dict: Mapping from a training key to keyword arguments for
            ``LM.finetune``. Each value must contain an ``"lm"`` entry, which
            is the LM to fine-tune; the remaining entries are forwarded to
            ``lm.finetune``. The mapping and its values are not modified.

    Returns:
        Mapping from each training key to the result of its job's
        ``result()`` call (the fine-tuned LM).

    Raises:
        KeyError: If a value in ``finetune_dict`` has no ``"lm"`` entry.
    """
    num_jobs = len(finetune_dict)
    print(f"[BootstrapFinetune] Starting {num_jobs} fine-tuning job(s)...")
    # TODO(nit) Pass an identifier to the job so that we can tell the logs
    # coming from different fine-tune threads.

    # Launch every job before waiting on any of them. Jobs expose a
    # ``thread`` attribute, so they presumably run concurrently.
    key_to_job = {}
    for key, finetune_kwargs in finetune_dict.items():
        # Pop from a shallow copy so the caller's dict keeps its "lm" entry.
        job_kwargs = dict(finetune_kwargs)
        lm = job_kwargs.pop("lm")
        key_to_job[key] = lm.finetune(**job_kwargs)

    key_to_lm = {}
    for job_num, (key, job) in enumerate(key_to_job.items(), start=1):
        key_to_lm[key] = job.result()
        job.thread.join()
        print(f"[BootstrapFinetune] Job {job_num}/{num_jobs} is done")

    return key_to_lm