[Semantic Scholar] – [Code] – [Tweet] – [Video]
- 2023-08-25: First draft. This paper is a preprint.
Problem Settings
The paper tries to solve the data selection problem for language model pretraining. Suppose a collection of high-quality samples x_1', \cdots, x_n' follows a target distribution p, and another collection of N low-quality samples follows a raw distribution q; how could we select a subset x_1, \cdots, x_k\ (k \ll N) from the low-quality pool whose underlying distribution approximates p?
Heuristic Classification
This is the approach used by GPT-3, EleutherAI's Pile dataset, and PaLM to select the training corpus. Specifically,
- Training a fasttext regression model f: \mathcal{X} \rightarrow [0, 1] using unigrams and bigrams from high-quality datasets.
- Applying the trained classifier to the low-quality data collection and keeping x_i if `np.random.pareto(alpha) > 1 - score`. \alpha is chosen as 9 in the GPT-3 paper:
 
  We chose \alpha=9 in order to take mostly documents the classifier scored highly, but still include some documents that were out of distribution.
 
 With \alpha = 9, draws from the Pareto distribution are typically small, so only documents with high classifier scores clear the threshold 1 - score; this keeps mostly high-quality samples while still admitting a few out-of-distribution documents (see the sketch below).
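Below is a minimal sketch of this filtering rule, assuming score holds the classifier output for a document in [0, 1]; the fastText model itself is omitted and the example scores are made up.
import numpy as np

alpha = 9  # Pareto shape used in the GPT-3 paper

def keep_document(score):
    # A high score makes the threshold 1 - score small, so the document almost always passes;
    # a low score only passes when the heavy-tailed Pareto draw happens to be large.
    return np.random.pareto(alpha) > 1 - score

# hypothetical classifier scores for four documents
scores = [0.95, 0.7, 0.3, 0.05]
kept = [s for s in scores if keep_document(s)]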
Data Selection with Importance Resampling (DSIR)
The issue with heuristic classification is that it does not explicitly model the underlying data distributions. Instead, the authors of DSIR model the distribution of a corpus explicitly as follows:
p(z; \gamma)=\prod_{j=1}^{10000} \gamma_j^{z_j}
where
- $z$ is a 10000-dimensional count vector; entry $z_j$ counts the unigrams and bigrams hashed into bucket $j$ (with potential collisions).
- $\gamma$ is the parameter to learn.
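Concretely, \gamma can be estimated by maximum likelihood as the normalized bucket counts of a corpus, which is what the unified function in the Code section does via source_dist / source_dist.sum():
\hat{\gamma}_j = \frac{\sum_i z_{i,j}}{\sum_{j'} \sum_i z_{i,j'}}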
After learning the distributions p and q, we could assign a weight to each sample of the pool, w_i = \frac{p(z_i)}{q(z_i)},\ i = 1, \cdots, N. We could then resample the pool with weights w_1, \cdots, w_N until we have collected k samples. The authors sample without replacement and justify this choice theoretically; the justification could have been stronger, since deduplication is a key aspect of language model pretraining (Falcon paper).
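Since both distributions factorize over the hash buckets, the log importance weight reduces to an inner product between the count vector z_i and the difference of the log parameters (writing \hat{\gamma}^{\text{target}} and \hat{\gamma}^{\text{raw}} for the parameters estimated from the target and raw corpora); this is exactly the quantity the code below computes as `np.inner(curr_count, log_diff_dist)`:
\log w_i = \log p(z_i; \hat{\gamma}^{\text{target}}) - \log q(z_i; \hat{\gamma}^{\text{raw}}) = \sum_{j=1}^{10000} z_{i,j}\left(\log \hat{\gamma}^{\text{target}}_j - \log \hat{\gamma}^{\text{raw}}_j\right)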
Experiments
- The average KL reduction strongly correlates (r = 0.89) with the accuracies on the downstream task (i.e., GLUE), and DSIR significantly improves the downstream accuracies. This correlation is a post-hoc justification for modeling a dataset as p(z;\gamma).
The KL reduction is defined as follows, where \hat{p} is the target distribution, \hat{q} is the raw distribution, and p' is the distribution obtained after some data selection method (random selection, expert curation, or the proposed algorithm):
\frac{1}{\vert \mathcal{T}\vert} \sum_{\hat{p} \in \mathcal{T}} \left[\mathrm{KL}(\hat{p} \parallel \hat{q}) - \mathrm{KL}(\hat{p} \parallel p')\right], \quad \mathrm{KL}(p \parallel q) = H(p, q) - H(p)
The average over \mathcal{T} appears because each data selection method is evaluated across a set of target distributions; each of the n data selection algorithms therefore receives a single score (a NumPy sketch of this computation follows this list).
- Continued pretraining of RoBERTa on domain-specific datasets sampled with the DSIR algorithm improves upon models fine-tuned on datasets sampled with the baseline methods (Tables 1, 2).
- Training BERT from scratch on data sampled with the different approaches and then fine-tuning on GLUE shows the proposed selection algorithm's advantages over heuristic classification and random selection (Table 4).
- It is important to make sure the domain of the pretraining dataset matches the deployment domain, as (1) performance typically drops when the domains differ (Table 3), and (2) domain transfer is hard to predict (Figure 3).
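Under the hashed n-gram model above, each distribution is just a normalized 10000-dimensional bucket-count vector, so the KL reduction can be computed directly. Below is a minimal sketch of the formula referenced above; the helper names are placeholders rather than code from the paper.
import numpy as np

def kl(p, q, eps=1e-8):
    # KL(p || q) = H(p, q) - H(p) for discrete distributions, smoothed to avoid log(0).
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

def kl_reduction(target_dists, raw_dist, selected_dist):
    # Average, over the target distributions, of KL(p_hat || q_hat) - KL(p_hat || p').
    return float(np.mean([kl(p_hat, raw_dist) - kl(p_hat, selected_dist) for p_hat in target_dists]))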
Code
The implementation of the DSIR algorithm is straightforward:
- Obtaining the Count Vector of a String
A count vector of `text` could be obtained with `get_ngram_info(text)` (i.e., z = h(x)). The `hash_buckets` function returns an integer in [0, 9999].
from itertools import islice

import numpy as np
from nltk.tokenize import WordPunctTokenizer

tokenizer = WordPunctTokenizer()

def hash_buckets(string, num_buckets=10000):
    # Hash a unigram (str) or bigram (tuple) into one of num_buckets buckets.
    return int(abs(hash(string)) % num_buckets)

def get_ngram_info(line, n=2, num_buckets=10000):
    words = tokenizer.tokenize(line.lower())
    # "i love it" would be converted to:
    # - unigrams: ['i', 'love', 'it']
    # - bigrams: [('i', 'love'), ('love', 'it')]
    unigrams, bigrams = words, list(zip(words, islice(words, 1, None)))
    counts = np.zeros(num_buckets, dtype=int)
    for unigram in unigrams:
        counts[hash_buckets(unigram, num_buckets=num_buckets)] += 1
    for bigram in bigrams:
        counts[hash_buckets(bigram, num_buckets=num_buckets)] += 1
    return counts
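A quick sanity check of the count vector (note that Python's built-in hash for strings is randomized per process unless PYTHONHASHSEED is fixed, so the bucket indices are not stable across runs):
counts = get_ngram_info("i love it")
assert counts.shape == (10000,)
assert counts.sum() == 5  # 3 unigrams + 2 bigrams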
 
- Obtaining the Importance Weight
The code snippet reads each line from the Pile dataset and assigns a score `logratio` to the line; the `logratio` will be used for sampling. Here `target_dist` and `pile_dist` are the normalized bucket-count vectors (i.e., the estimated \gamma) of the target and raw corpora.
import json
import numpy as np
from tqdm import tqdm

# Element-wise difference of the log parameters of the target and raw (Pile) models.
log_diff_dist = np.log(target_dist + 1e-8) - np.log(pile_dist + 1e-8)

logratios = []
with open(path, 'r') as f:
    lines = f.readlines()
    for line in tqdm(lines, miniters=1000000, maxinterval=1000000):
        ex = json.loads(line)
        line = ex["contents"]
        curr_count = get_ngram_info(line, n=2)
        logratio = np.inner(curr_count, log_diff_dist)
        logratios.append(logratio)
 
- Sample Selection
The Gumbel distribution models the distribution of maximum values; the authors use it here to ensure the sampling is done without replacement (i.e., the Gumbel top-k trick).
logratios = np.array(logratios)  # convert the list to an array so the Gumbel noise is added element-wise
logratios += np.random.gumbel(size=len(logratios))
# select the num_to_retrieve samples with the highest perturbed logratios
chosen_idxs = np.argpartition(-logratios, num_to_retrieve)[:num_to_retrieve]
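The trick can be checked empirically: adding i.i.d. Gumbel noise to the log-weights and keeping the top k indices yields the same inclusion frequencies as drawing k items one by one, each time proportionally to the remaining weights. A small self-contained check, with made-up weights for illustration:
import numpy as np

rng = np.random.default_rng(0)
weights = np.array([0.5, 0.3, 0.15, 0.05])  # made-up sampling weights
logw = np.log(weights)
k, trials = 2, 20000

gumbel_counts = np.zeros(len(weights))
seq_counts = np.zeros(len(weights))
for _ in range(trials):
    # Gumbel top-k: perturb the log-weights with i.i.d. Gumbel noise and keep the k largest.
    keys = logw + rng.gumbel(size=len(weights))
    gumbel_counts[np.argsort(-keys)[:k]] += 1
    # Reference: successive sampling proportional to the remaining weights.
    remaining = weights.copy()
    for _ in range(k):
        i = rng.choice(len(weights), p=remaining / remaining.sum())
        seq_counts[i] += 1
        remaining[i] = 0.0

print(gumbel_counts / trials)  # per-item inclusion frequencies under Gumbel top-k
print(seq_counts / trials)     # should closely match the line above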
 
Based on the functions above, we could write a simple unified function `select_data_using_dsir(source_df, target_df, num_to_retrieve)`. This function is applicable to the scenario where both the source and target datasets are small enough to be processed in memory.
import numpy as np
from tqdm import tqdm

def select_data_using_dsir(source_df, target_df, num_to_retrieve):
    source_texts = source_df.text.tolist()
    target_texts = target_df.text.tolist()
    # Estimate the hashed n-gram distributions of the source (raw) and target corpora.
    source_count_vecs = [
        get_ngram_info(source_text) for source_text in tqdm(source_texts, desc="modeling source texts")
    ]
    source_dist = np.sum(source_count_vecs, axis=0)
    target_dist = np.sum(
        [get_ngram_info(target_text) for target_text in tqdm(target_texts, desc="modeling target texts")],
        axis=0
    )
    source_dist = source_dist / source_dist.sum()
    target_dist = target_dist / target_dist.sum()
    # Log importance weight of each source sample: z_i . (log gamma_target - log gamma_source).
    log_diff_dist = np.log(target_dist + 1e-8) - np.log(source_dist + 1e-8)
    logratios = np.array([
        np.inner(source_count_vec, log_diff_dist)
        for source_count_vec in tqdm(source_count_vecs, desc="retrieving source texts")
    ])
    # Gumbel top-k: perturb the log weights and keep the num_to_retrieve largest.
    logratios += np.random.gumbel(size=len(logratios))
    source_df["score"] = logratios
    chosen_idxs = np.argpartition(-logratios, num_to_retrieve)[:num_to_retrieve]
    return source_df.iloc[chosen_idxs]
import pandas as pd
from datasets import load_dataset

# Example usage: treat imdb-tlm as the raw source pool and ag-tlm as the target corpus.
source_dataset = load_dataset("yxchar/imdb-tlm", split="train")
target_dataset = load_dataset("yxchar/ag-tlm", split="train")
source_df = pd.DataFrame(source_dataset)
target_df = pd.DataFrame(target_dataset)
selected_source_df = select_data_using_dsir(source_df, target_df, num_to_retrieve=100)