first commit

parent f1ddf69649
commit d872b8c103
@@ -0,0 +1,36 @@
{
  "_name_or_path": "/jxm/cde/cde-small-v2/checkpoint-2635",
  "architecture": "transductive",
  "architectures": [
    "ContextualDocumentEmbeddingTransformer"
  ],
  "attn_implementation": null,
  "auto_map": {
    "AutoConfig": "model.ContextualModelConfig",
    "AutoModel": "model.ContextualDocumentEmbeddingTransformer"
  },
  "autoregressive_backbone": false,
  "cache_dir": null,
  "config_name": null,
  "dataset_backbone": null,
  "disable_dropout": true,
  "disable_transductive_rotary_embedding": true,
  "embedder": "answerdotai/ModernBERT-base",
  "embedder_rerank": "sentence-transformers/gtr-t5-base",
  "embedding_output_dim": null,
  "limit_layers": null,
  "limit_layers_first_stage": null,
  "logit_scale": 50.0,
  "max_seq_length": 512,
  "model_revision": "main",
  "pool_ignore_contextual_tokens": true,
  "pool_ignore_instruction_tokens": true,
  "pooling_strategy": "mean",
  "tokenizer_name": null,
  "torch_dtype": "float32",
  "transductive_corpus_size": 512,
  "transductive_sequence_dropout_prob": 0.0,
  "transductive_tie_token_embeddings": false,
  "transductive_tokens_per_document": 1,
  "transformers_version": "4.48.0.dev0"
}
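
Note: the auto_map entries above point AutoConfig/AutoModel at the custom classes in model.py, so the checkpoint has to be loaded with remote code enabled. A minimal loading sketch (the hub id below is a placeholder, not part of this commit; the tokenizer choice follows the "embedder" field):

    import transformers

    # Placeholder repository id; substitute the actual checkpoint path or hub id.
    model = transformers.AutoModel.from_pretrained(
        "<your-org>/cde-small-v2",
        trust_remote_code=True,  # required: auto_map resolves to model.ContextualDocumentEmbeddingTransformer
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
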
@@ -0,0 +1,13 @@
{
  "__version__": {
    "sentence_transformers": "3.1.0",
    "transformers": "4.43.4",
    "pytorch": "2.5.0.dev20240807+cu121"
  },
  "prompts": {
    "query": "search_query: ",
    "document": "search_document: "
  },
  "default_prompt_name": null,
  "similarity_fn_name": "cosine"
}
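
The prompts block above is what sentence-transformers uses to prefix inputs with "search_query: " / "search_document: " before encoding, and similarity_fn_name selects cosine similarity. A minimal sketch of how that looks at inference time, assuming the checkpoint is published as a SentenceTransformer model (the model id is a placeholder):

    from sentence_transformers import SentenceTransformer

    # Placeholder model id; substitute the real repository.
    model = SentenceTransformer("<your-org>/cde-small-v2", trust_remote_code=True)
    # prompt_name picks the matching prefix from config_sentence_transformers.json.
    query_emb = model.encode(["what is a contextual document embedding?"], prompt_name="query")
    doc_emb = model.encode(["CDE conditions document embeddings on a corpus sample."], prompt_name="document")
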
@@ -0,0 +1,518 @@
from typing import Dict, Iterable, List, Optional, Tuple, Union

import collections
import glob
import json
import hashlib
import itertools
import logging
import multiprocessing
import os
import pickle
import random
import requests
import sys
import zipfile

import datasets
import numpy as np
import torch
import tqdm
import transformers

from cde.lib.dist import get_num_proc, get_rank


def get_cde_cache_dir() -> str:
    script_directory = os.path.normpath(
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            os.pardir, os.pardir,
        )
    )
    return os.path.join(script_directory, "data")


def get_cache_location_from_kwargs(**kwargs):
    cache_location = os.path.join(
        get_cde_cache_dir(), "cluster"
    )
    os.makedirs(cache_location, exist_ok=True)
    return os.path.join(cache_location, md5_hash_kwargs(**kwargs))


def process_qrels_uncached(corpus: datasets.Dataset, qrels: datasets.Dataset) -> Tuple[Dict[str, List[int]], Dict[str, List[float]]]:
    qrels_idxs = collections.defaultdict(list)
    qrels_scores = collections.defaultdict(list)
    corpus_ids = np.array(corpus['_id'])
    skipped_qrels = 0

    for ex in tqdm.tqdm(qrels, desc='processing qrels', colour='#964B00', leave=False):
        #
        # example:
        # {
        #     'query-id': 1,
        #     'corpus-id': 'b0680508-2019-04-18T13:48:51Z-00002-000',
        #     'score': 2
        # }
        #
        q_id = str(ex['query-id'])
        c_idxs = (corpus_ids == str(ex['corpus-id'])).nonzero()[0]
        #
        assert len(c_idxs) <= 1, f"error - duplicate corpus ID? (found {len(c_idxs)} matches)"
        #
        if len(c_idxs):
            qrels_idxs[q_id].append(c_idxs[0])
            qrels_scores[q_id].append(ex['score'])
        else:
            skipped_qrels += 1
        #

    if skipped_qrels > 0:
        logging.warning(f'Warning: Skipped {skipped_qrels}/{len(qrels)} qrels.')

    return qrels_idxs, qrels_scores


def process_qrels(
    corpus: datasets.Dataset, qrels: datasets.Dataset,
    use_cache: bool = True
) -> Tuple[Dict[str, List[int]], Dict[str, List[float]]]:
    dataset_cache_file = '_'.join(
        (corpus.cache_files[0]['filename'], qrels.cache_files[0]['filename'])
    )
    cache_file = strip_extension(dataset_cache_file) + '_processed_qrels.p'
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    if not (use_cache and os.path.exists(cache_file)):
        qrels_idxs, qrels_scores = process_qrels_uncached(
            corpus=corpus, qrels=qrels
        )
        if use_cache:
            pickle.dump((qrels_idxs, qrels_scores), open(cache_file, 'wb'))
    else:
        qrels_idxs, qrels_scores = pickle.load(open(cache_file, 'rb'))

    return qrels_idxs, qrels_scores


def strip_extension(filename: str) -> str:
    """Strips file extension.

    Ex:
        >> strip_extension('/root/dir/sub/file.ext')
        '/root/dir/sub/file'
    """
    return os.path.splitext(filename)[0]


def md5_hash(t: Tuple[str]) -> str:
    return hashlib.md5('__'.join(t).encode()).hexdigest()


def md5_hash_kwargs(**kwargs) -> str:
    # We ignore special hf args that start with _ like '__cached__setup_devices'.
    safe_kwargs = {k: str(v) for k, v in kwargs.items() if not k.startswith('_')}
    s = json.dumps(safe_kwargs, sort_keys=True)
    return hashlib.md5(s.encode()).hexdigest()


def download_url(url: str, save_path: str, chunk_size: int = 1024):
    """Download url with progress bar using tqdm.

    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads

    Args:
        url (str): downloadable url
        save_path (str): local path to save the downloaded file
        chunk_size (int, optional): chunking of files. Defaults to 1024.
    """
    r = requests.get(url, stream=True)
    total = int(r.headers.get('Content-Length', 0))
    with open(save_path, 'wb') as fd, tqdm.tqdm(
        desc=save_path,
        total=total,
        unit='iB',
        unit_scale=True,
        unit_divisor=chunk_size,
    ) as bar:
        for data in r.iter_content(chunk_size=chunk_size):
            size = fd.write(data)
            bar.update(size)


def unzip(zip_file: str, out_dir: str):
    print("unzipping =>", zip_file)
    zip_ = zipfile.ZipFile(zip_file, "r")
    zip_.extractall(path=out_dir)
    zip_.close()


def download_url_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
    os.makedirs(out_dir, exist_ok=True)
    dataset = url.split("/")[-1]
    zip_file = os.path.join(out_dir, dataset)

    if not os.path.isfile(zip_file):
        logging.info("Downloading {} ...".format(dataset))
        download_url(url, zip_file, chunk_size)

    if not os.path.isdir(zip_file.replace(".zip", "")):
        logging.info("Unzipping {} ...".format(dataset))
        unzip(zip_file, out_dir)

    return os.path.join(out_dir, dataset.replace(".zip", ""))


def tqdm_if_main_worker(iterable: Iterable, **kwargs) -> Iterable:
    if get_rank() == 0:
        return tqdm.tqdm(iterable, **kwargs)
    else:
        return iterable


class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
    """We create a dummy configuration class that will just set properties
    based on whatever kwargs we pass in.

    When this class is initialized (see experiments.py) we pass in the
    union of all data, model, and training args, all of which should
    get saved to the config json.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            try:
                json.dumps(value)
                setattr(self, key, value)
            except TypeError:
                # value was not JSON-serializable, skip
                continue
        super().__init__()


def independent_crop(
        input_ids: torch.Tensor, pad_token_id: int,
        l1: int = 256, l2: int = 256) -> Tuple[torch.Tensor, torch.Tensor]:
    """Returns two independent crops from input_ids.

    Assumes input_ids has a beginning and end token, like
        [101, ..., 102, 0, 0, 0].

    Args:
        input_ids: tensor of IDs
        pad_token_id: ID of pad tokens in input_ids
        l1: length of span 1, cropped
        l2: length of span 2, cropped
    Returns:
        span1: first crop (of length l1)
        span2: second crop (of length l2)
    """
    # Count tokens until pad.
    if (input_ids == pad_token_id).sum() == 0:
        N = len(input_ids)
    else:
        N = (input_ids == pad_token_id).int().argmax().item()

    ####
    ###
    ##
    ## Contriever: We use the random cropping data
    ## augmentation, with documents of 256 tokens and span
    ## sizes sampled between 5% and 50% of the document
    ## length
    ##
    ###
    #####
    ####### LaPraDor: The maximum lengths set for queries and
    ####### documents are 64 and 350...
    #####
    # TODO is this divide-by-two a good idea? (Don't want s1=s2 ever..)
    nl1 = min(N // 2, l1)
    nl2 = min(N // 2, l2)

    s1_start = random.randint(1, N - nl1)
    s2_start = random.randint(1, N - nl2)

    s1_idxs = itertools.chain(
        [0], range(s1_start, s1_start + nl1), [N - 1]
    )
    s1 = input_ids[torch.tensor(list(s1_idxs))]
    s2_idxs = itertools.chain(
        [0], range(s2_start, s2_start + nl2), [N - 1]
    )
    s2 = input_ids[torch.tensor(list(s2_idxs))]
    return (s1, s2)


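# Illustrative example of independent_crop (hypothetical tensor values, not
# part of the original file): two crops are drawn from the same sequence, each
# re-attaching the first token and the last non-pad token.
#
#   input_ids = torch.tensor([101, 5, 6, 7, 8, 9, 102, 0, 0])
#   s1, s2 = independent_crop(input_ids, pad_token_id=0, l1=3, l2=3)
#   # each crop has at most l + 2 tokens and starts with input_ids[0]

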
def load_dataset_tables(
    files: Iterable[str], num_workers: int = 16
) -> Iterable[datasets.table.MemoryMappedTable]:
    import concurrent.futures
    from multiprocessing import Pool

    # num_workers = min(num_workers, len(files))
    num_workers = min(32, len(files))

    use_threads = True
    if use_threads:
        pool_cls = concurrent.futures.ThreadPoolExecutor
        pool_kwargs = {"max_workers": num_workers}
    else:
        pool_cls = Pool
        pool_kwargs = {"processes": num_workers}

    with pool_cls(**pool_kwargs) as pool:
        if len(files) > 10:
            files = tqdm_if_main_worker(
                files,
                desc=f"Loading {len(files)} files with {num_workers} workers",
                total=len(files),
                colour="#ffbd88"
            )

        result = list(
            pool.map(datasets.table.MemoryMappedTable.from_file, files)
        )
    return result


def datasets_fast_load_from_disk(cache_path: str) -> datasets.Dataset:
    logging.info(f"fast_load_from_disk called with path: {cache_path}")
    dataset_info_path = os.path.join(cache_path, "dataset_info.json")
    with open(dataset_info_path, encoding="utf-8") as dataset_info_file:
        dataset_info = datasets.DatasetInfo.from_dict(json.load(dataset_info_file))

    dataset_state_path = os.path.join(cache_path, "state.json")
    with open(dataset_state_path, encoding="utf-8") as state_file:
        state = json.load(state_file)

    files = glob.glob(os.path.join(cache_path, "data-*.arrow"))
    files = sorted(files)
    num_workers = get_num_proc()
    ds_tables = load_dataset_tables(
        files=files,
        num_workers=num_workers
    )
    arrow_table = datasets.table.concat_tables(ds_tables)

    split = state["_split"]
    split = datasets.splits.Split(split) if split is not None else split

    # print("returning dataset")
    return datasets.Dataset(
        arrow_table=arrow_table,
        info=dataset_info,
        split=split,
        fingerprint=state["_fingerprint"],
    )


def tokenize_dataset(
    dataset: datasets.Dataset,
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int,
    text_key: str,
    padding_strategy: str
) -> datasets.Dataset:
    def tokenize_text(ex: Dict) -> Dict:
        tt = tokenizer(
            ex[text_key],
            max_length=max_length,
            truncation=True,
            padding=padding_strategy,
        )
        for k, v in tt.items():
            ex[f"{text_key}_{k}"] = v
        ex["length"] = [len(ids) for ids in ex[f"{text_key}_input_ids"]]
        return ex

    # generate unique hash for tokenizer
    vocab = tokenizer.vocab
    vocab_words = tuple(sorted(vocab.keys(), key=lambda word: vocab[word]))
    vocab_hash = md5_hash(vocab_words)

    data_fingerprint = '__'.join((
        dataset._fingerprint, str(vocab_hash), str(max_length),
        text_key, padding_strategy
    ))
    data_fingerprint = md5_hash(data_fingerprint)
    dataset = dataset.map(
        tokenize_text,
        new_fingerprint=data_fingerprint,
        batched=True,
        load_from_cache_file=True,
    )
    return dataset


class TensorRunningAverages:
    _store_sum: Dict[str, torch.Tensor]
    _store_total: Dict[str, torch.Tensor]

    def __init__(self):
        self._store_sum = {}
        self._store_total = {}

    def __iter__(self) -> Iterable[str]:
        return iter(self._store_sum.keys())

    def update(self, key: str, val: Union[int, float, torch.Tensor]) -> None:
        if key not in self._store_sum:
            self.clear(key)
        if isinstance(val, torch.Tensor):
            val = val.item()  # tensor -> num
        self._store_sum[key] += val
        self._store_total[key] += 1

    def get(self, key: str) -> float:
        total = max(self._store_total.get(key).item(), 1.0)
        return (self._store_sum[key] / float(total)).item() or 0.0

    def clear(self, key: str) -> None:
        self._store_sum[key] = torch.tensor(0.0, dtype=torch.float32)
        self._store_total[key] = torch.tensor(0, dtype=torch.int32)

    def clear_all(self) -> None:
        for key in self._store_sum:
            self.clear(key)

    def get_and_clear_all(self) -> Dict[str, float]:
        metrics = {}
        for key in self:
            metrics[key] = self.get(key)
            self.clear(key)
        return metrics


def load_embedder_and_tokenizer(name: str) -> Tuple[
    transformers.PreTrainedModel,
    transformers.PreTrainedTokenizer
]:
    if name.startswith("nomic") or (name == "bert-base-uncased"):
        from cde.lib.nomic_bert import NomicBertModel
        if name.endswith("--from-scratch"):
            name = name.replace("--from-scratch", "")
            config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
            model = NomicBertModel._from_config(config)
        else:
            model = NomicBertModel.from_pretrained(
                name, add_pooling_layer=False
            )
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    elif name in ["gtr-base", "gtr_base"]:
        model = transformers.AutoModel.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        )
    elif name == "pile-t5-base-encoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name == "pile-t5-base-decoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).decoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name):
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name,
            # torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else "sdpa",
            low_cpu_mem_usage=True,
            # device_map="auto",
        )
        model.padding_side = "right"
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
    elif "Modern" in name:
        print("special loading for ModernBERT!")
        # [1] needed for faster training
        # model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True, reference_compile=True)
        # [2] needed for non-breaking inference
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True, reference_compile=False)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    else:
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    return model, tokenizer


def inputs_for_key(inputs: Dict[str, torch.Tensor], key: str):
    key += "_"
    return {k.replace(key, ""): v for k, v in inputs.items() if k.startswith(key)}


def count_cpus() -> int:
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        return multiprocessing.cpu_count()


def shuffle_batches(g: torch.Generator, list_of_tensors: List[torch.Tensor]) -> List[int]:
    all_indices = []
    for batch_tensor in tqdm_if_main_worker(list_of_tensors, colour="green", desc="Sampler shuffling per-batch"):
        rand_perm = torch.randperm(len(batch_tensor), generator=g)
        batch_list = batch_tensor[rand_perm].tolist()
        all_indices.extend(batch_list)
    return all_indices


# def shuffle_batches_multiproc(g: torch.Generator, list_of_tensors: List[torch.Tensor], num_processes: int = 8) -> List[int]:
#     all_indices = []
#     print(f"Shuffling {len(list_of_tensors)} tensors with {num_processes} workers.")
#     pbar = tqdm_if_main_worker(list_of_tensors, colour="orange", desc=f"Sampler shuffling per-batch (nproc={num_processes})")
#     pool = multiprocessing.Pool(processes=num_processes)
#     chunk_size = len(list_of_tensors) // num_processes
#     chunks = [list_of_tensors[i:i + chunk_size] for i in range(0, len(list_of_tensors), chunk_size)]
#     worker_func = functools.partial(shuffle_batches, g=g)
#     results = pool.map(worker_func, chunks)
#     all_indices = []
#     for result in results:
#         all_indices.extend(result)
#         pbar.update()
#     return all_indices


def exit_if_running_or_finished_wandb(
    project_name: str,
    exp_group: str, exp_name: str
) -> None:
    print("Checking if experiment is already running...")
    import wandb

    api = wandb.Api()
    running_runs = api.runs(
        path="cde-0",
        filters={
            "display_name": exp_name,
            "state": {"$regex": "Running|Finished"},
            "config.exp_group": exp_group,
        }
    )
    print("Found", len(running_runs), f"runs with name {exp_name} and group {exp_group} in {project_name}.")

    if len(running_runs) > 0:
        print("Exiting because experiment is already running or completed.")
        sys.exit(0)


HN_FILTER_TOKENIZER_MAP = {
    "nomic": "nomic-ai/nomic-embed-text-v1",
    "stella": "dunzhang/stella_en_400M_v5",
    "sbert": "sentence-transformers/all-MiniLM-L6-v2",
    "sentence_t5": "sentence-transformers/sentence-t5-base",
    "gte": "Alibaba-NLP/gte-large-en-v1.5",
}


def load_hn_filter_tokenizer(tokenizer_name: str) -> Optional[transformers.PreTrainedTokenizer]:
    if tokenizer_name in HN_FILTER_TOKENIZER_MAP:
        return transformers.AutoTokenizer.from_pretrained(HN_FILTER_TOKENIZER_MAP[tokenizer_name])
    else:
        return None
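
A quick illustration of the TensorRunningAverages helper defined in the file above (hypothetical numbers, shown only to make the API concrete; it assumes the module above is importable):

    import torch

    averages = TensorRunningAverages()
    averages.update("loss", torch.tensor(0.52))
    averages.update("loss", torch.tensor(0.48))
    print(averages.get_and_clear_all())  # approximately {'loss': 0.5}
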
Binary file not shown.
@@ -0,0 +1,9 @@
[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers_impl.Transformer",
    "kwargs": ["dataset_embeddings"]
  }
]
@@ -0,0 +1 @@
{}
@@ -0,0 +1,155 @@
from __future__ import annotations

import json
import logging
import os
from typing import Any, Optional

import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)


class Transformer(nn.Module):
    """Hugging Face AutoModel to generate token embeddings.
    Loads the correct class, e.g. BERT / RoBERTa etc.

    Args:
        model_name_or_path: Hugging Face models name
            (https://huggingface.co/models)
        max_seq_length: Truncate any inputs longer than max_seq_length
        model_args: Keyword arguments passed to the Hugging Face
            Transformers model
        tokenizer_args: Keyword arguments passed to the Hugging Face
            Transformers tokenizer
        config_args: Keyword arguments passed to the Hugging Face
            Transformers config
        cache_dir: Cache dir for Hugging Face Transformers to store/load
            models
        do_lower_case: If true, lowercases the input (independent if the
            model is cased or not)
        tokenizer_name_or_path: Name or path of the tokenizer. When
            None, then model_name_or_path is used
        backend: Backend used for model inference. Can be `torch`, `onnx`,
            or `openvino`. Default is `torch`.
    """

    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str,
        model_args: dict[str, Any] | None = None,
        tokenizer_args: dict[str, Any] | None = None,
        config_args: dict[str, Any] | None = None,
        cache_dir: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__()
        if model_args is None:
            model_args = {}
        if tokenizer_args is None:
            tokenizer_args = {}
        if config_args is None:
            config_args = {}

        if not model_args.get("trust_remote_code", False):
            raise ValueError(
                "You need to set `trust_remote_code=True` to load this model."
            )

        self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
        self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)

        self.tokenizer = AutoTokenizer.from_pretrained(
            "answerdotai/ModernBERT-base",
            cache_dir=cache_dir,
            **tokenizer_args,
        )

    def __repr__(self) -> str:
        return f"Transformer({self.get_config_dict()}) with Transformer model: {self.auto_model.__class__.__name__} "

    def forward(self, features: dict[str, torch.Tensor], dataset_embeddings: Optional[torch.Tensor] = None, **kwargs) -> dict[str, torch.Tensor]:
        """Returns token_embeddings, cls_token"""
        # If we don't have embeddings, then run the 1st stage model.
        # If we do, then run the 2nd stage model.
        if dataset_embeddings is None:
            sentence_embedding = self.auto_model.first_stage_model(
                input_ids=features["input_ids"],
                attention_mask=features["attention_mask"],
            )
        else:
            sentence_embedding = self.auto_model.second_stage_model(
                input_ids=features["input_ids"],
                attention_mask=features["attention_mask"],
                dataset_embeddings=dataset_embeddings,
            )

        features["sentence_embedding"] = sentence_embedding
        return features

    def get_word_embedding_dimension(self) -> int:
        return self.auto_model.config.hidden_size

    def tokenize(
        self, texts: list[str] | list[dict] | list[tuple[str, str]], padding: str | bool = True
    ) -> dict[str, torch.Tensor]:
        """Tokenizes a text and maps tokens to token-ids"""
        output = {}
        if isinstance(texts[0], str):
            to_tokenize = [texts]
        elif isinstance(texts[0], dict):
            to_tokenize = []
            output["text_keys"] = []
            for lookup in texts:
                text_key, text = next(iter(lookup.items()))
                to_tokenize.append(text)
                output["text_keys"].append(text_key)
            to_tokenize = [to_tokenize]
        else:
            batch1, batch2 = [], []
            for text_tuple in texts:
                batch1.append(text_tuple[0])
                batch2.append(text_tuple[1])
            to_tokenize = [batch1, batch2]

        max_seq_length = self.config.max_seq_length
        output.update(
            self.tokenizer(
                *to_tokenize,
                padding=padding,
                truncation="longest_first",
                return_tensors="pt",
                max_length=max_seq_length,
            )
        )
        return output

    def get_config_dict(self) -> dict[str, Any]:
        return {}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)

        with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @classmethod
    def load(cls, input_path: str) -> Transformer:
        sbert_config_path = os.path.join(input_path, "sentence_bert_config.json")
        if not os.path.exists(sbert_config_path):
            return cls(model_name_or_path=input_path)

        with open(sbert_config_path) as fIn:
            config = json.load(fIn)
        # Don't allow configs to set trust_remote_code
        if "model_args" in config and "trust_remote_code" in config["model_args"]:
            config["model_args"].pop("trust_remote_code")
        if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
            config["tokenizer_args"].pop("trust_remote_code")
        if "config_args" in config and "trust_remote_code" in config["config_args"]:
            config["config_args"].pop("trust_remote_code")
        return cls(model_name_or_path=input_path, **config)
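
The forward() method above is the two-stage dispatch at the heart of this module: without dataset_embeddings it runs auto_model.first_stage_model, otherwise auto_model.second_stage_model. A hedged sketch of the intended flow, assuming `module` is an already-loaded instance of the Transformer class above (first-stage outputs over a corpus sample become the conditioning context for later calls):

    # Hypothetical usage; the texts are placeholders.
    corpus_features = module.tokenize(["first corpus document", "second corpus document"])
    dataset_embeddings = module(corpus_features)["sentence_embedding"]  # first stage

    query_features = module.tokenize(["a search query"])
    query_embedding = module(query_features, dataset_embeddings=dataset_embeddings)["sentence_embedding"]  # second stage
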