diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..48b9a8d Binary files /dev/null and b/.DS_Store differ diff --git a/allegro.py b/allegro.py new file mode 100644 index 0000000..e0cac0f --- /dev/null +++ b/allegro.py @@ -0,0 +1,119 @@ +import os +import re +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling +from datasets import Dataset + +# Konfiguracja +os.environ["TOKENIZERS_PARALLELISM"] = "false" +MODEL_NAME = "allegro/herbert-base-cased" +SPECIAL_TOKENS = ["[CITATION_START]", "[CITATION_END]"] +TEXT_FILE_PATH = "./docs/kodekspracy.txt" # Zmień na właściwą ścieżkę + +def prepare_dataset_from_file(file_path): + with open(file_path, 'r', encoding='utf-8') as f: + text = f.read() + + articles = re.findall(r'Art\.\s*\d+[a-z]*\..*?(?=\s*Art\.\s*\d+[a-z]*\.|\Z)', text, flags=re.DOTALL) + + formatted_articles = [] + for article in articles: + article = ' '.join(article.strip().split()) + + art_match = re.match(r'Art\.\s*(\d+[a-z]*)\.?\s*(.*)', article, re.DOTALL) + if art_match: + art_number = art_match.group(1) + art_text = art_match.group(2) + + paragraphs = re.split(r'(§\s*\d+\.)', art_text) + if len(paragraphs) > 1: + formatted_paragraphs = [] + for i in range(1, len(paragraphs), 2): + para_num = paragraphs[i].strip() + para_text = paragraphs[i+1].strip() + formatted_paragraphs.append(f"{para_num} {para_text}") + formatted = f"[CITATION_START] Kodeks Pracy, Art. {art_number} [CITATION_END]\n" + "\n".join(formatted_paragraphs) + else: + formatted = f"[CITATION_START] Kodeks Pracy, Art. {art_number} [CITATION_END] {art_text}" + + formatted_articles.append({"text": formatted}) + + questions = [ + f"Zacytuj artykuł {art_number} Kodeksu pracy.", + f"Co mówi artykuł {art_number} Kodeksu pracy?", + f"Podaj treść artykułu {art_number} Kodeksu pracy." 
+ ] + for question in questions: + formatted_articles.append({"text": f"{question}\n{formatted}"}) + + return formatted_articles + +def main(): + # Inicjalizacja tokenizera + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS}) + + print(f"Pad token: {tokenizer.pad_token}") + print(f"Pad token ID: {tokenizer.pad_token_id}") + + # Przygotowanie danych + data = prepare_dataset_from_file(TEXT_FILE_PATH) + dataset = Dataset.from_dict({"text": [d["text"] for d in data]}) + + # Tokenizacja + def tokenize_function(examples): + tokenized = tokenizer( + examples["text"], + truncation=True, + padding="max_length", + max_length=512, + return_tensors="pt" + ) + tokenized["labels"] = tokenized["input_ids"].clone() + return tokenized + + tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names) + + # Model i data collator + model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) + model.resize_token_embeddings(len(tokenizer)) + model.config.pad_token_id = tokenizer.pad_token_id + + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False + ) + + # Konfiguracja treningu + training_args = TrainingArguments( + output_dir="./results", + num_train_epochs=32, + per_device_train_batch_size=2, + learning_rate=1e-5, + logging_steps=10, + weight_decay=0.01, + report_to="none", + save_strategy="steps", + save_steps=500, + evaluation_strategy="steps", + eval_steps=500, + load_best_model_at_end=True, + ) + + # Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset, + eval_dataset=tokenized_dataset, + data_collator=data_collator + ) + + print("Rozpoczęcie treningu...") + trainer.train() + trainer.save_model("./trained_model/allegro") + tokenizer.save_pretrained("./trained_model/allegro") + +if __name__ == "__main__": + main() diff --git a/catalog.json b/catalog.json new file mode 100644 index 0000000..128779e --- /dev/null +++ b/catalog.json @@ -0,0 +1,5 @@ +{ + "kodekspracy": "Kodeks Pracy", + "urlopproporcjonalny": "Rozporządzenie BHP", + "ustawaopanstwowejinspekcjipracy": "Ustawa o Państwowej inspekcji pracy" +} \ No newline at end of file diff --git a/files/kodekspracy.md b/docs/kodekspracy.txt similarity index 99% rename from files/kodekspracy.md rename to docs/kodekspracy.txt index 9179a1a..8b1e293 100644 --- a/files/kodekspracy.md +++ b/docs/kodekspracy.txt @@ -2,7 +2,7 @@ USTAWA z dnia 26 czerwca 1974 r. -Kodeks pracy1) +Kodeks pracy (Dz. U. z 2023 r. poz. 1465 oraz z 2024 r. poz. 878, 1222, 1871 i 1965) @@ -11,8 +11,6 @@ obowiązuje od dnia 1 stycznia 1975 r. historia od dnia 16 lutego 1998 r. 
-Preambuła (uchylona) - DZIAŁ PIERWSZY Przepisy ogólne diff --git a/files/urlopproporcjonalny.md b/docs/urlopproporcjonalny.txt similarity index 100% rename from files/urlopproporcjonalny.md rename to docs/urlopproporcjonalny.txt diff --git a/files/Ustawa o Państwowej inspekcji pracy.pdf b/docs/ustawaopanstwowejinspekcjipracy.pdf similarity index 100% rename from files/Ustawa o Państwowej inspekcji pracy.pdf rename to docs/ustawaopanstwowejinspekcjipracy.pdf diff --git a/file_catalog.json b/file_catalog.json deleted file mode 100644 index cd37ea0..0000000 --- a/file_catalog.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "kodekspracy.md": "Kodeks pracy" -} \ No newline at end of file diff --git a/gemma-faiss.py b/gemma-faiss.py new file mode 100644 index 0000000..4c46cad --- /dev/null +++ b/gemma-faiss.py @@ -0,0 +1,93 @@ +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +import faiss +import numpy as np +import ollama +import gradio as gr +import os +import argparse +from sentence_transformers import SentenceTransformer + +# === KONFIGURACJA === +model_name = "hse.ably.do:latest" # Nazwa modelu Ollama +faiss_index_path = "faiss_index.idx" # Plik indeksu FAISS +kodeks_file = "/home/ably.do/docs/kodekspracy.txt" # Plik z treścią kodeksu pracy +embedding_model = SentenceTransformer("all-MiniLM-L6-v2") # Model do embedowania tekstu + +# === KROK 1: WCZYTYWANIE KODEKSU PRACY === +def load_kodeks(filepath): + with open(filepath, "r", encoding="utf-8") as file: + content = file.read() + articles = content.split("\n\n") # Dzielimy na sekcje + return [article.strip() for article in articles if article.strip().startswith("Art.")] + +# === KROK 2: TWORZENIE INDEKSU FAISS === +def create_faiss_index(sections): + embeddings = embedding_model.encode(sections, convert_to_numpy=True) # Tworzenie wektorów + index = faiss.IndexFlatL2(embeddings.shape[1]) # Indeks FAISS + index.add(embeddings) # Dodanie wektorów do FAISS + faiss.write_index(index, faiss_index_path) # Zapis indeksu + return index, sections + +# === KROK 3: WYSZUKIWANIE NAJBLIŻSZEGO FRAGMENTU === +def search_faiss(query, index, sections, top_k=3): + query_vector = embedding_model.encode([query], convert_to_numpy=True) + _, idx = index.search(query_vector, top_k) # Szukamy więcej wyników + + results = [sections[i] for i in idx[0] if i < len(sections)] + return "\n\n".join(results) # Połącz kilka najlepszych fragmentów + +# === KROK 4: GENEROWANIE ODPOWIEDZI Z OLLAMA === +def generate_response(user_query): + if not os.path.exists(faiss_index_path): + return "Błąd: Indeks FAISS nie istnieje. Uruchom aplikację z opcją --rebuild-index." + + try: + index = faiss.read_index(faiss_index_path) + except Exception as e: + return f"Błąd ładowania FAISS: {str(e)}" + + sections = load_kodeks(kodeks_file) + best_match = search_faiss(user_query, index, sections) + + # 👀 DEBUG: Sprawdź, co zwraca FAISS + print(f"🔍 Najlepsze dopasowanie FAISS dla '{user_query}':\n{best_match}") + + prompt = f""" + Odpowiedz na pytanie na podstawie następującego tekstu: + + {best_match} + + Pytanie: {user_query} + Podaj dokładny tekst artykułu, jeśli go znajdziesz w treści powyżej. 
+ """ + + response = ollama.chat(model=model_name, messages=[{"role": "user", "content": prompt}]) + + print(f"📝 Odpowiedź modelu:\n{response}") # 👀 DEBUG: Sprawdź odpowiedź Ollama + + return response.get("message", response.get("content", "Błąd: Nie udało się wygenerować odpowiedzi.")) + +# === KROK 5: INTERFEJS WEBOWY === +iface = gr.Interface( + fn=generate_response, + inputs=gr.Textbox(label="Zadaj pytanie o kodeks pracy"), + outputs=gr.Textbox(label="Odpowiedź"), + title="Asystent Kodeksu Pracy", + description="Wpisz pytanie, a system zwróci odpowiedni fragment kodeksu pracy." +) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--rebuild-index", action="store_true", help="Odbudowanie indeksu FAISS") + args = parser.parse_args() + + if args.rebuild_index or not os.path.exists(faiss_index_path): + print("Tworzenie nowego indeksu FAISS...") + sections = load_kodeks(kodeks_file) + create_faiss_index(sections) + else: + print("Indeks FAISS już istnieje.") + + iface.launch(share=True) \ No newline at end of file diff --git a/gemma.py b/gemma.py new file mode 100644 index 0000000..11c7caf --- /dev/null +++ b/gemma.py @@ -0,0 +1,119 @@ +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +import torch +import faiss +import numpy as np +from sentence_transformers import SentenceTransformer +from datasets import Dataset +from peft import LoraConfig, get_peft_model +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling + +# 1️⃣ Inicjalizacja modelu do embeddingów +embed_model = SentenceTransformer("all-MiniLM-L6-v2") + +# 2️⃣ Dodanie dokumentów i embeddingów +def read_documents_from_file(file_path): + with open(file_path, 'r', encoding='utf-8') as file: + content = file.read() + articles = content.split('\n\n') + documents = [] + for article in articles: + if article.strip().startswith('Art.'): + documents.append(article.strip()) + return documents +#documents = [ +# "Jak założyć firmę w Polsce?", +# "Jak rozliczyć podatek VAT?", +# "Procedura składania reklamacji w e-sklepie.", +# "Jakie dokumenty są potrzebne do rejestracji działalności?" 
+#] +file_path = './docs/kodekspracy.txt' # Zmień na właściwą ścieżkę +documents = read_documents_from_file(file_path) +embeddings = embed_model.encode(documents) + +# 3️⃣ Inicjalizacja FAISS i dodanie wektorów +dim = embeddings.shape[1] +index = faiss.IndexFlatL2(dim) +index.add(np.array(embeddings, dtype=np.float32)) + +# 4️⃣ Przygotowanie danych treningowych +def create_training_data(): + data = { + "text": documents, + "embedding": embeddings.tolist() + } + return Dataset.from_dict(data) + +dataset = create_training_data() + +# Podział danych na treningowe i ewaluacyjne +split_dataset = dataset.train_test_split(test_size=0.25) +train_dataset = split_dataset["train"] +eval_dataset = split_dataset["test"] + +# 5️⃣ Ładowanie modelu Gemma 2B +device = "cuda" if torch.cuda.is_available() else "cpu" +model_name = "google/gemma-2-2b" +model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device) +tokenizer = AutoTokenizer.from_pretrained(model_name) + +# 6️⃣ Konfiguracja LoRA +lora_config = LoraConfig( + r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM" +) +model = get_peft_model(model, lora_config) + +# 7️⃣ Tokenizacja danych +max_length = 384 + +def tokenize_function(examples): + return tokenizer( + examples["text"], + padding="max_length", + truncation=True, + max_length=max_length + ) + +tokenized_train = train_dataset.map(tokenize_function, batched=True) +tokenized_eval = eval_dataset.map(tokenize_function, batched=True) + +# 8️⃣ Parametry treningu +training_args = TrainingArguments( + output_dir="./results", + eval_strategy="steps", # Ewaluacja co określoną liczbę kroków + eval_steps=500, # Ewaluacja co 500 kroków + save_strategy="steps", # Zapis modelu co określoną liczbę kroków + save_steps=500, # Zapis modelu co 500 kroków + learning_rate=1e-5, + per_device_train_batch_size=2, + per_device_eval_batch_size=2, + num_train_epochs=16, + weight_decay=0.01, + load_best_model_at_end=True, # Wczytaj najlepszy model na końcu + metric_for_best_model="loss", # Kryterium wyboru najlepszego modelu + greater_is_better=False, # Niższy loss = lepszy model +) + +# 9️⃣ Data Collator +data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False +) + +# 🔟 Trening modelu +trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_train, + eval_dataset=tokenized_eval, # Dodany zestaw ewaluacyjny + data_collator=data_collator, +) + +trainer.train() + +# 1️⃣1️⃣ Zapis modelu +model.save_pretrained("./trained_model/gemma") +tokenizer.save_pretrained("./trained_model/gemma") + +print("✅ Model został wytrenowany i zapisany!") \ No newline at end of file diff --git a/gpt.py b/gpt.py new file mode 100644 index 0000000..dfa57e9 --- /dev/null +++ b/gpt.py @@ -0,0 +1,118 @@ +import os +import re +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling +from datasets import Dataset + +# Konfiguracja +os.environ["TOKENIZERS_PARALLELISM"] = "false" +MODEL_NAME = "gpt2-medium" +SPECIAL_TOKENS = ["[CITATION_START]", "[CITATION_END]"] +TEXT_FILE_PATH = "./docs/kodekspracy.txt" # Zmień na właściwą ścieżkę + +def prepare_dataset_from_file(file_path): + with open(file_path, 'r', encoding='utf-8') as f: + text = f.read() + + # Wydziel artykuły za pomocą wyrażenia regularnego + articles = re.findall(r'Art\.\s*\d+[a-z]*\..*?(?=\s*Art\.\s*\d+[a-z]*\.|\Z)', text, flags=re.DOTALL) + + formatted_articles = [] + for article in articles: + # Usuń zbędne 
białe znaki + article = ' '.join(article.strip().split()) + + # Wydziel numer artykułu i treść + art_match = re.match(r'Art\.\s*(\d+[a-z]*)\.?\s*(.*)', article, re.DOTALL) + if art_match: + art_number = art_match.group(1) + art_text = art_match.group(2) + + # Podziel na paragrafy, jeśli istnieją + paragraphs = re.split(r'(§\s*\d+\.)', art_text) + if len(paragraphs) > 1: + formatted_paragraphs = [] + for i in range(1, len(paragraphs), 2): + para_num = paragraphs[i].strip() + para_text = paragraphs[i+1].strip() + formatted_paragraphs.append(f"{para_num} {para_text}") + formatted = f"[CITATION_START] Kodeks Pracy, Art. {art_number} [CITATION_END]\n" + "\n".join(formatted_paragraphs) + else: + formatted = f"[CITATION_START] Kodeks Pracy, Art. {art_number} [CITATION_END] {art_text}" + + formatted_articles.append({"text": formatted}) + + # Dodaj przykłady pytań i odpowiedzi + questions = [ + f"Zacytuj artykuł {art_number} Kodeksu pracy.", + f"Co mówi artykuł {art_number} Kodeksu pracy?", + f"Podaj treść artykułu {art_number} Kodeksu pracy." + ] + for question in questions: + formatted_articles.append({"text": f"{question}\n{formatted}"}) + + return formatted_articles + + +def main(): + # Inicjalizacja tokenizera + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS}) + tokenizer.pad_token = tokenizer.eos_token + + # Przygotowanie danych + data = prepare_dataset_from_file(TEXT_FILE_PATH) + dataset = Dataset.from_dict({"text": [d["text"] for d in data]}) + + # Tokenizacja + def tokenize_function(examples): + tokenized = tokenizer( + examples["text"], + truncation=True, + padding="max_length", + max_length=1024, # Zwiększono dla dłuższych artykułów + return_tensors="pt" + ) + tokenized["labels"] = tokenized["input_ids"].clone() + return tokenized + + tokenized_dataset = dataset.map(tokenize_function, batched=True) + + # Model i data collator + model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) + model.resize_token_embeddings(len(tokenizer), mean_resizing=False) + + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False + ) + + # Konfiguracja treningu + training_args = TrainingArguments( + output_dir="./results", + num_train_epochs=32, # Zwiększono liczbę epok + per_device_train_batch_size=2, + learning_rate=1e-5, #precyzja uczenia + logging_steps=10, + weight_decay=0.01, + report_to="none", + save_strategy="no", + load_best_model_at_end=True, # Ładowanie najlepszego modelu na końcu + ) + + + # Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset, + data_collator=data_collator + ) + + print("Rozpoczęcie treningu...") + trainer.train() + trainer.save_model("./trained_model/gpt") + tokenizer.save_pretrained("./trained_model/gpt") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/hft.py b/hft.py index a3069f3..1f2a4e7 100644 --- a/hft.py +++ b/hft.py @@ -1,14 +1,14 @@ import os import torch import torch.nn as nn -from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer +from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling from datasets import Dataset -from PIL import Image import re -import pytesseract -import docx2txt -import PyPDF2 import json +import PyPDF2 +import docx2txt +import pytesseract +from PIL import Image from collections import defaultdict from huggingface_hub import login @@ -34,127 +34,205 @@ class 
SourceMapper: return self.idx_to_source.get(idx, "Unknown") def load_file_catalog(catalog_path): - with open(catalog_path, 'r', encoding='utf-8') as file: - return json.load(file) + try: + with open(catalog_path, 'r', encoding='utf-8') as file: + return json.load(file) + except Exception as e: + print(f"Błąd wczytywania katalogu plików: {str(e)}") + return {} def identify_legal_document(filename, file_catalog): - return file_catalog.get(filename, "Opracowanie własne") + base_name = os.path.splitext(filename)[0].lower() + return file_catalog.get(base_name, "Opracowanie własne") def extract_text_from_file(file_path): - _, ext = os.path.splitext(file_path) - ext = ext.lower() - - if ext in ['.txt', '.md']: - with open(file_path, 'r', encoding='utf-8') as file: - return file.read() - elif ext == '.pdf': - text = "" - with open(file_path, 'rb') as file: - reader = PyPDF2.PdfReader(file) - for page in reader.pages: - text += page.extract_text() - return text - elif ext in ['.doc', '.docx']: - return docx2txt.process(file_path) - elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']: - return pytesseract.image_to_string(Image.open(file_path)) - else: + try: + _, ext = os.path.splitext(file_path) + ext = ext.lower() + + if ext in ['.txt', '.md']: + with open(file_path, 'r', encoding='utf-8') as file: + return file.read() + elif ext == '.pdf': + text = "" + try: + with open(file_path, 'rb') as file: + reader = PyPDF2.PdfReader(file) + for page in reader.pages: + text += page.extract_text() or "" + except Exception as e: + print(f"Błąd PDF: {str(e)}") + return text + elif ext in ['.doc', '.docx']: + return docx2txt.process(file_path) + elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']: + return pytesseract.image_to_string(Image.open(file_path)) + else: + print(f"Nieobsługiwany format pliku: {ext}") + return "" + except Exception as e: + print(f"Błąd ekstrakcji tekstu: {str(e)}") return "" def prepare_dataset(directory, catalog_path, source_mapper): file_catalog = load_file_catalog(catalog_path) data = [] + print(f"\n{'='*50}\nDIAGNOSTYKA DANYCH\n{'='*50}") + for root, _, files in os.walk(directory): for file in files: file_path = os.path.join(root, file) - text = extract_text_from_file(file_path) - if not text: + print(f"\nPrzetwarzanie pliku: {file_path}") + + try: + text = extract_text_from_file(file_path) + if not text.strip(): + print("Pominięto - brak tekstu") + continue + + print(f"Długość tekstu: {len(text)} znaków") + + doc_type = identify_legal_document(file, file_catalog) + print(f"Rozpoznany typ dokumentu: {doc_type}") + + if doc_type != "Opracowanie własne": + articles = re.split(r'(?i)(Art[\.\s]+\d+[\.\s]?)', text) + articles = [a.strip() for a in articles if a.strip()] + + print(f"Znaleziono {len(articles)} fragmentów") + + for i in range(0, len(articles)-1, 2): + article_number = articles[i] + article_content = articles[i+1] + + if len(article_content) < 50: + continue + + source = f"{doc_type}, {article_number}" + source_mapper.add_source(source) + data.append({ + "text": f"{article_number} {article_content}", + "source_idx": source_mapper.get_idx(source) + }) + else: + clean_text = re.sub(r'\s+', ' ', text).strip() + chunks = [clean_text[i:i+512] for i in range(0, len(clean_text), 512)] + chunks = [c for c in chunks if c.strip()] + + for chunk in chunks: + data.append({ + "text": chunk, + "source_idx": -1 + }) + print(f"Dodano {len(chunks)} chunków") + + except Exception as e: + print(f"Błąd podczas przetwarzania pliku: {str(e)}") continue - doc_type = 
identify_legal_document(file, file_catalog) - if doc_type != "Opracowanie własne": - articles = re.split(r'(Art\.\s+\d+[\.\s])', text) - for i in range(1, len(articles), 2): - article_number = articles[i].strip() - article_content = articles[i+1].strip() if i+1 < len(articles) else "" - source = f"{doc_type}, {article_number}" - source_mapper.add_source(source) - - data.append({ - "text": f"{article_number} {article_content}", - "source_idx": source_mapper.get_idx(source) - }) - else: - chunks = [text[i:i+512] for i in range(0, len(text), 512)] - for chunk in chunks: - data.append({ - "text": chunk, - "source_idx": -1 - }) + print(f"\nPodsumowanie przygotowania danych:") + print(f"Łączna liczba przykładów: {len(data)}") + if data: + print("Przykładowy wpis:") + print(json.dumps(data[0], indent=2, ensure_ascii=False)) + else: + print("BRAK DANYCH - sprawdź diagnostykę powyżej") + return data -def tokenize_function(examples): - tokenized = tokenizer( - examples["text"], - truncation=True, - padding="max_length", - max_length=512, - return_tensors="pt" - ) - tokenized["labels"] = tokenized["input_ids"].clone() - tokenized["source_idx"] = examples["source_idx"] - return tokenized - -def custom_collate_fn(batch): - input_ids = torch.stack([torch.tensor(b["input_ids"]) for b in batch]) - attention_mask = torch.stack([torch.tensor(b["attention_mask"]) for b in batch]) - labels = torch.stack([torch.tensor(b["labels"]) for b in batch]) - source_idx = torch.tensor([b.get("source_idx", -1) for b in batch], dtype=torch.long) - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "labels": labels, - "source_idx": source_idx - } - class CustomModel(nn.Module): def __init__(self, model_name, config): super().__init__() self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config) - self.source_embedding = nn.Embedding(1000, config.hidden_size, padding_idx=-1) + self.source_embedding = nn.Embedding(10000, config.hidden_size, padding_idx=-1) + for param in self.base_model.parameters(): + param.requires_grad = False + for param in self.base_model.get_output_embeddings().parameters(): + param.requires_grad = True + def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs): if source_idx is not None: - source_idx = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1) - source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1) + valid_indices = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1) + source_embeds = self.source_embedding(valid_indices).unsqueeze(1) inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds - return self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs) - return self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs) + return self.base_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + labels=labels, + **kwargs + ) + return self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + **kwargs + ) def generate(self, *args, **kwargs): return self.base_model.generate(*args, **kwargs) +class CustomDataCollator(DataCollatorForLanguageModeling): + def torch_call(self, examples): + # Przetwórz podstawowe pola + input_ids = torch.stack([torch.tensor(ex["input_ids"]) for ex in examples]) + attention_mask = torch.stack([torch.tensor(ex["attention_mask"]) for ex in examples]) + labels = 
torch.stack([torch.tensor(ex["labels"]) for ex in examples]) + + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels + } + + # Dodaj source_idx jeśli istnieje + if "source_idx" in examples[0]: + source_idx = torch.stack([torch.tensor(ex["source_idx"]) for ex in examples]) + batch["source_idx"] = source_idx + + return batch + def main(): - # Inicjalizacja source_mapper = SourceMapper() model_name = "crumb/nano-mistral" tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token = tokenizer.eos_token # Przygotowanie danych - catalog_path = "file_catalog.json" - data = prepare_dataset("files", catalog_path, source_mapper) - dataset = Dataset.from_list(data) - tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8) + catalog_path = "catalog.json" + data = prepare_dataset("docs", catalog_path, source_mapper) + + if not data: + print("\nBrak danych do treningu!") + return - # Model - config = AutoModelForCausalLM.from_pretrained(model_name).config - model = CustomModel(model_name, config) + #dataset = Dataset.from_list(data) + dataset = Dataset.from_dict({k: [d[k] for d in data] for k in data[0]}) + + + def tokenize_function(examples): + tokenized = tokenizer( + examples["text"], + truncation=True, + padding="max_length", + max_length=512, + return_tensors="pt" + ) + return { + "input_ids": tokenized["input_ids"].squeeze(), + "attention_mask": tokenized["attention_mask"].squeeze(), + "labels": tokenized["input_ids"].squeeze().clone(), + "source_idx": examples["source_idx"] # Dodano bez konwersji do tensora + } + + tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16) + + model = CustomModel(model_name, AutoModelForCausalLM.from_pretrained(model_name).config) + model.source_mapper = source_mapper device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) - # Trening training_args = TrainingArguments( output_dir="./results", num_train_epochs=3, @@ -162,66 +240,22 @@ def main(): gradient_accumulation_steps=4, learning_rate=2e-5, fp16=torch.cuda.is_available(), - logging_steps=1, + logging_steps=10, save_strategy="steps", save_steps=1000, - report_to="none" + report_to="none", + remove_unused_columns=False ) trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_dataset, - data_collator=custom_collate_fn, + data_collator=CustomDataCollator(tokenizer=tokenizer, mlm=False) ) - print("Rozpoczęcie treningu...") + + print("\nRozpoczęcie treningu...") trainer.train() - # Testowanie - def generate_answer(question): - inputs = tokenizer(question, return_tensors="pt").to(device) - - outputs = model.generate( - **inputs, - max_new_tokens=200, - temperature=0.7, - top_p=0.9, - do_sample=True, - repetition_penalty=1.2, - no_repeat_ngram_size=2, - pad_token_id=tokenizer.eos_token_id - ) - - answer = tokenizer.decode(outputs[0], skip_special_tokens=True) - answer = answer.replace(question, "").strip() - - sources = set() - for match in re.finditer(r'Art\.\s+\d+', answer): - article_ref = match.group(0).strip() - for idx, source in source_mapper.idx_to_source.items(): - if article_ref in source: - sources.add(source) - - return { - "question": question, - "answer": answer, - "sources": list(sources) if sources else ["Opracowanie własne"] - } - - # Przykładowe testy - test_questions = [ - "Jakie są zasady udzielania urlopu wypoczynkowego?", - "Co mówi art. 154 kodeksu pracy?", - "Jakie są obowiązki pracodawcy w zakresie BHP?" 
- ] - - print("\n" + "="*50 + "\nWYNIKI TESTOW\n" + "="*50) - for question in test_questions: - result = generate_answer(question) - print(f"\nPYTANIE: {result['question']}") - print(f"ODPOWIEDŹ: {result['answer'][:500]}") - print(f"ŹRÓDŁA: {', '.join(result['sources'])}") - print("-"*80) - if __name__ == "__main__": main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..cfc1745 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +torch>=2.0.1 +transformers>=4.30.2 +datasets>=2.13.1 +Pillow>=9.4.0 +pytesseract>=0.3.10 +python-docx>=0.8.11 +PyPDF2>=3.0.1 +huggingface-hub>=0.16.4 \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..1b77e8b --- /dev/null +++ b/test.py @@ -0,0 +1,22 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_path = "./trained_model/gpt" +model = AutoModelForCausalLM.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path) +tokenizer.pad_token = tokenizer.eos_token +model.config.pad_token_id = tokenizer.eos_token_id + +def generate_response(prompt, max_length=1000): + inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True) + outputs = model.generate( + inputs.input_ids, + attention_mask=inputs.attention_mask, + pad_token_id=tokenizer.pad_token_id, + max_length=max_length # użyj parametru funkcji zamiast sztywnej wartości 100 + ) + response = tokenizer.decode(outputs[0], skip_special_tokens=True) + return response + +prompt = "Zacytuj paragraf pierwszy artykułu 154 Kodeksu pracy." +response = generate_response(prompt) +print(response) \ No newline at end of file
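Note: the diff ships an inference script only for the GPT-2 checkpoint (test.py). Below is a minimal inference sketch for the LoRA adapter produced by gemma.py; it assumes the adapter and tokenizer were saved to ./trained_model/gemma as in that script, and the prompt and generation settings are illustrative, not part of the repository.

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "google/gemma-2-2b"        # base model used in gemma.py
ADAPTER_PATH = "./trained_model/gemma"  # adapter and tokenizer saved by gemma.py (assumed path)

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=dtype).to(device)
model = PeftModel.from_pretrained(base, ADAPTER_PATH).to(device)
model.eval()

prompt = "Co mówi artykuł 154 Kodeksu pracy?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    # Greedy decoding keeps the output as close as possible to the memorised article text
    outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Greedy decoding is chosen here because the training objective is verbatim citation of articles; sampling parameters would tend to paraphrase rather than quote.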