# reasoning_queries.py

import argparse
import os

from datasets import load_dataset
from openai import OpenAI
from pqdm.processes import pqdm
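
# OpenAI() reads credentials from the OPENAI_API_KEY environment variable.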
client = OpenAI()


def reformulate_with_4o(query):
    query_template = """
Given a query:
1. Repeat the query.
2. Identify the essential problem.
3. Think step by step to reason and describe what information could be relevant and helpful to address
the questions in detail.
4. Draft an answer with as many thoughts as you have.
Answer in the same language as the query.
Query: {query}
"""
    prompt = query_template.format(query=query)
    completion = client.chat.completions.create(
        model="gpt-4.1",
        messages=[
            {"role": "developer", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
    )
    return completion.choices[0].message.content.strip()


def process_dataset(dataset_name, query_column="query"):
    """
    Download a dataset, reformulate its queries using 4o, and re-upload it.

    Args:
        dataset_name: Name of the Hugging Face dataset.
        query_column: Column containing the queries.
    """
    # Load the "queries" config, "test" split
    print(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name, "queries", split="test")

    # Determine the number of cores to use
    n_jobs = os.cpu_count()
    print(f"Using {n_jobs} processes for parallel processing")
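    # Note: os.cpu_count() is a CPU-oriented default; the work here is I/O-bound
    # API calls, so a larger n_jobs could be reasonable as well.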

    # Prepare the queries for processing
    queries = dataset[query_column]
    print(f"Processing {len(queries)} queries using 4O reformulation...")

    # Process queries in parallel
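    # Note: pqdm does not raise worker exceptions by default; failed calls come
    # back as exception objects in the result list.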
    reformulated_queries = pqdm(list(queries), reformulate_with_4o, n_jobs=n_jobs)
    print("Reformulation complete. Adding to dataset...")
    print(reformulated_queries[:5])  # Show the first 5 reformulated queries for verification

    # Add the reformulated queries as a new column
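    # (Dataset.add_column returns a new dataset rather than modifying in place.)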
    updated_dataset = dataset.add_column("gpt-4o-reasoning", reformulated_queries)
    print("Column added!")

    # Push the updated split back to the Hugging Face Hub
    # (requires authentication, e.g. `huggingface-cli login` or an HF_TOKEN env var)
    updated_dataset.push_to_hub(dataset_name, "queries", split="test")
    print("Upload complete!")
    return updated_dataset


def main():
    parser = argparse.ArgumentParser(description="Reformulate queries in a dataset using the 4O technique")
    parser.add_argument("--dataset", required=True, help="Hugging Face dataset name")
    parser.add_argument("--query_column", default="query", help="Column containing the queries")
    args = parser.parse_args()
    process_dataset(args.dataset, args.query_column)


if __name__ == "__main__":
    main()
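
# Example invocation (the dataset name is hypothetical):
#   python reasoning_queries.py --dataset my-org/my-benchmark --query_column query
# Requires OPENAI_API_KEY in the environment and a Hugging Face login (or HF_TOKEN)
# for the final push_to_hub step.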