From d5049b651ce6ae6a24305531ac2b836fa5f156cb Mon Sep 17 00:00:00 2001
From: "l.gabrysiak"
Date: Tue, 25 Feb 2025 21:52:06 +0100
Subject: [PATCH] mod

---
 hft.py | 212 ++++++++++++++++++++++-----------------------------------
 1 file changed, 82 insertions(+), 130 deletions(-)

diff --git a/hft.py b/hft.py
index 0a05bb5..764226d 100644
--- a/hft.py
+++ b/hft.py
@@ -11,12 +11,11 @@ import pytesseract
 from PIL import Image
 from collections import defaultdict
 from huggingface_hub import login
-from torch.utils.data import DataLoader
 
 # Configuration
 os.environ['TORCH_USE_CUDA_DSA'] = '1'
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-login(token="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")  # Replace with your own token
+login(token="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")
 
 class SourceMapper:
     def __init__(self):
@@ -98,39 +97,39 @@ def prepare_dataset(directory, catalog_path, source_mapper):
         print(f"Rozpoznany typ dokumentu: {doc_type}")
 
         if doc_type != "Opracowanie własne":
-            # Improved regular expression for different formats
-            articles = re.split(r'(?i)(Art[^\S\n]*\.?[^\S\n]*\d+[^\S\n]*\.?)', text)
+            articles = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', text)
             articles = [a.strip() for a in articles if a.strip()]
-            print(f"Znaleziono {len(articles)//2} artykułów")
+            print(f"Znaleziono {len(articles)} fragmentów")
 
+            # Generate a larger number of training examples
             for i in range(0, len(articles)-1, 2):
-                article_number = articles[i]
-                article_content = articles[i+1]
-
-                if len(article_content) < 50:
-                    print(f"Pominięto zbyt krótki artykuł: {article_number}")
-                    continue
+                for chunk_size in [256, 512, 1024]:  # Different chunk sizes
+                    article_number = articles[i]
+                    article_content = articles[i+1]
 
-                source = f"{doc_type}, {article_number}"
-                print(f"Dodano artykuł: {source}")
-
-                source_mapper.add_source(source)
-                data.append({
-                    "text": f"{article_number} {article_content}",
-                    "source_idx": source_mapper.get_idx(source)
-                })
+                    chunks = [article_content[j:j+chunk_size] for j in range(0, len(article_content), chunk_size//2)]
+                    chunks = [c for c in chunks if len(c) > 100]
+
+                    for chunk in chunks:
+                        source = f"{doc_type}, {article_number}"
+                        source_mapper.add_source(source)
+                        data.append({
+                            "text": f"{article_number} {chunk}",
+                            "source_idx": source_mapper.get_idx(source)
+                        })
         else:
             clean_text = re.sub(r'\s+', ' ', text).strip()
-            chunks = [clean_text[i:i+512] for i in range(0, len(clean_text), 512)]
-            chunks = [c for c in chunks if c.strip()]
-
-            for chunk in chunks:
-                data.append({
-                    "text": chunk,
-                    "source_idx": -1
-                })
-            print(f"Dodano {len(chunks)} chunków")
+            for chunk_size in [256, 512, 768]:  # Three different sizes
+                chunks = [clean_text[i:i+chunk_size] for i in range(0, len(clean_text), chunk_size//2)]
+                chunks = [c for c in chunks if c.strip()]
+
+                for chunk in chunks:
+                    data.append({
+                        "text": chunk,
+                        "source_idx": -1
+                    })
+            print(f"Dodano {len(chunks)*3} chunków")
 
     except Exception as e:
         print(f"Błąd podczas przetwarzania pliku: {str(e)}")
@@ -150,44 +149,32 @@ class CustomModel(nn.Module):
     def __init__(self, model_name, config):
         super().__init__()
         self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
-        self.source_embedding = nn.Embedding(1000, config.hidden_size, padding_idx=-1)
+        self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1)
 
-        # Freeze the base model's layers
+        # Fine-tune part of the base model
        for param in self.base_model.parameters():
             param.requires_grad = False
         for param in self.base_model.get_output_embeddings().parameters():
             param.requires_grad = True
+        for param in self.base_model.get_input_embeddings().parameters():
+            param.requires_grad = True
 
     def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
         if source_idx is not None:
             valid_indices = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
-
-            source_embeds = torch.nn.functional.normalize(
-                self.source_embedding(valid_indices),
-                p=2,
-                dim=-1
-            ).unsqueeze(1)
-
+            source_embeds = self.source_embedding(valid_indices).unsqueeze(1)
             inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
-
-            return self.base_model(
-                inputs_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                labels=labels,
-                **kwargs
-            )
-
-        return self.base_model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            labels=labels,
-            **kwargs
-        )
+            return self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)
+        return self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
 
     def generate(self, *args, **kwargs):
         return self.base_model.generate(*args, **kwargs)
 
 class CustomTrainer(Trainer):
+    def __init__(self, *args, **kwargs):
+        self.tokenizer = kwargs.pop('tokenizer', None)
+        super().__init__(*args, **kwargs)
+
     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
         labels = inputs.pop("labels")
         source_idx = inputs.pop("source_idx", None)
@@ -195,72 +182,54 @@ class CustomTrainer(Trainer):
         return (outputs.loss, outputs) if return_outputs else outputs.loss
 
     def evaluate(self):
-        val_questions = {
-            "art1": "Jakie są prawa pracownika według art. 1?",
-            "art2": "Kto jest pracownikiem według art. 2?",
-            "art3": "Jakie są obowiązki pracodawcy według art. 3?"
-        }
+        questions = [
+            "Jakie są prawa pracownika według art. 1?",
+            "Kto jest pracownikiem według art. 2?",
+            "Jakie są obowiązki pracodawcy według art. 3?"
+        ]
 
-        model.eval()
-        results = {}
-
-        for key, question in val_questions.items():
-            result = self.generate_answer(question)
-            results[key] = result
-
-        print("\nWyniki walidacji:")
-        for key, val in results.items():
-            print(f"\n{val_questions[key]}")
-            print(f"Odpowiedź: {val['answer'][:200]}...")
-            print(f"Źródła: {val['sources']}")
+        print("\n" + "="*50 + "\nEWALUACJA\n" + "="*50)
+        for q in questions:
+            result = self.generate_answer(q)
+            print(f"\nPYTANIE: {q}")
+            print(f"ODPOWIEDŹ: {result['answer'][:500]}")
+            print(f"ŹRÓDŁA: {', '.join(result['sources'])}")
+            print("-"*80)
 
         return {"loss": 0.0}
 
     def generate_answer(self, question):
-        tokenizer = self.tokenizer
-        model = self.model
-        device = model.base_model.device
-
-        prompt = f"[PYTANIE PRAWNE] {question} [KONTEKST]"
-
-        inputs = tokenizer(
-            prompt,
+        inputs = self.tokenizer(
+            f"[PYTANIE] {question} [KONTEKST]",
             return_tensors="pt",
             truncation=True,
             max_length=512
-        ).to(device)
+        ).to(self.model.base_model.device)
 
         with torch.no_grad():
-            outputs = model.generate(
+            outputs = self.model.generate(
                 **inputs,
-                max_new_tokens=150,
-                temperature=0.3,
-                top_k=50,
-                top_p=0.95,
-                repetition_penalty=1.8,
+                max_new_tokens=200,
+                temperature=0.5,
+                top_p=0.9,
+                repetition_penalty=2.0,
                 num_beams=3,
-                no_repeat_ngram_size=4,
-                early_stopping=True,
-                pad_token_id=tokenizer.eos_token_id
+                no_repeat_ngram_size=3
             )
 
-        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        answer = answer.replace(prompt, "").strip()
+        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        answer = answer.split("[KONTEKST]")[-1].strip()
 
         sources = set()
-        for match in re.finditer(r'(?i)art\.?\s*\d+\.?', answer):
-            article_ref = match.group(0).strip().rstrip('.')
-            for source in self.model.source_mapper.idx_to_source.values():
+        for match in re.finditer(r'(?i)art\.?\s*\d+', answer):
+            article_ref = match.group(0).strip()
+            for idx, source in self.model.source_mapper.idx_to_source.items():
                 if article_ref.lower() in source.lower():
                     sources.add(source)
 
-        return {
-            "answer": answer,
-            "sources": list(sources) if sources else ["Opracowanie własne"]
-        }
+        return {"answer": answer, "sources": list(sources)}
 
 def main():
-    # Initialization
     source_mapper = SourceMapper()
     model_name = "crumb/nano-mistral"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -271,70 +240,53 @@ def main():
     data = prepare_dataset("files", catalog_path, source_mapper)
 
     if not data:
-        print("\nBrak danych do treningu! Sprawdź pliki w katalogu 'files' i diagnostykę powyżej.")
+        print("\nBrak danych do treningu!")
         return
 
     dataset = Dataset.from_list(data)
 
-    def tokenize_function(examples):
-        tokenized = tokenizer(
+    def tokenize(examples):
+        return tokenizer(
             examples["text"],
             truncation=True,
             padding="max_length",
             max_length=512,
             return_tensors="pt"
         )
-        return {
-            "input_ids": tokenized["input_ids"][0],
-            "attention_mask": tokenized["attention_mask"][0],
-            "labels": tokenized["input_ids"][0].clone(),
-            "source_idx": examples["source_idx"]
-        }
+
+    tokenized_dataset = dataset.map(tokenize, batched=True, batch_size=16)
 
-    tokenized_dataset = dataset.map(tokenize_function, batched=False)
-
-    def custom_collate_fn(features):
-        return {
-            "input_ids": torch.stack([torch.tensor(f["input_ids"]) for f in features]),
-            "attention_mask": torch.stack([torch.tensor(f["attention_mask"]) for f in features]),
-            "labels": torch.stack([torch.tensor(f["labels"]) for f in features]),
-            "source_idx": torch.tensor([f["source_idx"] for f in features], dtype=torch.long)
-        }
-
-    # Model
-    config = AutoModelForCausalLM.from_pretrained(model_name).config
-    model = CustomModel(model_name, config)
-    model.source_mapper = source_mapper  # Attach the source mapping to the model
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
-
-    # Training
     training_args = TrainingArguments(
         output_dir="./results",
-        num_train_epochs=5,
+        num_train_epochs=8,
         per_device_train_batch_size=4,
-        gradient_accumulation_steps=2,
-        learning_rate=1e-5,
+        gradient_accumulation_steps=8,
+        learning_rate=5e-6,
         weight_decay=0.01,
         warmup_ratio=0.1,
         fp16=torch.cuda.is_available(),
-        logging_steps=10,
+        logging_steps=50,
         save_strategy="epoch",
-        evaluation_strategy="steps",
-        eval_steps=500,
+        eval_strategy="no",
         report_to="none",
         remove_unused_columns=False
     )
 
+    model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config)
+    model.source_mapper = source_mapper
+    model.to("cuda" if torch.cuda.is_available() else "cpu")
+
     trainer = CustomTrainer(
         model=model,
         args=training_args,
         train_dataset=tokenized_dataset,
-        data_collator=custom_collate_fn,
         tokenizer=tokenizer
     )
 
+    print("\nRozpoczęcie treningu...")
     trainer.train()
+
+    print("\nKońcowa ewaluacja...")
     trainer.evaluate()
 
 if __name__ == "__main__":
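
Note on the conditioning mechanism this patch settles on: CustomModel.forward looks up one learned vector per example's source index and adds it to every token embedding before running the (mostly frozen) base causal LM. The snippet below is a minimal, self-contained sketch of that step, not code from hft.py; it reuses the crumb/nano-mistral checkpoint from main(), while the table size, the sample sentence, and the source index are illustrative assumptions.

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "crumb/nano-mistral"        # same checkpoint as main(); any causal LM exposing hidden_size works
tokenizer = AutoTokenizer.from_pretrained(model_name)
base = AutoModelForCausalLM.from_pretrained(model_name)

# Illustrative assumption: a small table with one learned vector per known source.
source_embedding = nn.Embedding(16, base.config.hidden_size)

enc = tokenizer("Art. 1. Przykładowy przepis.", return_tensors="pt")
source_idx = torch.tensor([3])           # illustrative: index of this example's source

token_embeds = base.get_input_embeddings()(enc["input_ids"])   # (1, seq_len, hidden)
source_embeds = source_embedding(source_idx).unsqueeze(1)      # (1, 1, hidden), broadcast over seq_len

# Same conditioning step as the patched CustomModel.forward: add the source vector
# to every token embedding and feed inputs_embeds instead of input_ids.
outputs = base(
    inputs_embeds=token_embeds + source_embeds,
    attention_mask=enc["attention_mask"],
    labels=enc["input_ids"],             # standard causal-LM loss on the same tokens
)
print(outputs.loss)

Because forward() clamps indices with torch.clamp(source_idx, 0, num_embeddings-1), the -1 that prepare_dataset assigns to unattributed chunks is mapped to row 0 of the table rather than being skipped.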