From d82fe81f78c24b34b919681b904a3801a9d5ffc0 Mon Sep 17 00:00:00 2001
From: coolneng <akasroua@gmail.com>
Date: Thu, 20 May 2021 21:15:57 +0200
Subject: [PATCH] Change data representation in local search

---
 src/local_search.py | 43 ++++++++++++++++++++++++++++++++++++-------
 src/main.py         |  2 +-
 2 files changed, 37 insertions(+), 8 deletions(-)

diff --git a/src/local_search.py b/src/local_search.py
index 8237dee..a91210e 100644
--- a/src/local_search.py
+++ b/src/local_search.py
@@ -1,10 +1,39 @@
 from numpy.random import choice, seed
+from pandas import DataFrame
 
 
-def get_first_random_solution(m, data):
+def get_row_distance(source, destination, data):
+    row = data.query(
+        """(source == @source and destination == @destination) or \
+        (source == @destination and destination == @source)"""
+    )
+    return row["distance"].values[0]
+
+
+def compute_distance(element, solution, data):
+    accumulator = 0
+    distinct_elements = solution.query(f"point != {element}")
+    for _, item in distinct_elements.iterrows():
+        accumulator += get_row_distance(
+            source=element,
+            destination=item.point,
+            data=data,
+        )
+    return accumulator
+
+
+def get_first_random_solution(n, m, data):
+    solution = DataFrame(columns=["point", "distance"])
     seed(42)
-    random_indexes = choice(len(data.index), size=m, replace=False)
-    return data.loc[random_indexes]
+    solution["point"] = choice(n, size=m, replace=False)
+    solution["distance"] = solution["point"].apply(
+        func=compute_distance, solution=solution, data=data
+    )
+    return solution
+
+
+def evaluate_element_swap(solution, old_element, new_element, data):
+    pass
 
 
 def element_in_dataframe(solution, element):
@@ -22,14 +51,14 @@ def replace_worst_element(previous, data):
     while element_in_dataframe(solution=solution, element=random_element):
         random_element = data.sample().squeeze()
     solution.loc[worst_index] = random_element
-    return solution, worst_index
+    return solution
 
 
 def get_random_solution(previous, data):
     solution, worst_index = replace_worst_element(previous, data)
     previous_worst_distance = previous["distance"].loc[worst_index]
     while solution.distance.loc[worst_index] <= previous_worst_distance:
-        solution, _ = replace_worst_element(previous=solution, data=data)
+        solution = replace_worst_element(previous=solution, data=data)
     return solution
 
 
@@ -43,8 +72,8 @@ def explore_neighbourhood(element, data, max_iterations=100000):
     return neighbour
 
 
-def local_search(m, data):
-    first_solution = get_first_random_solution(m=m, data=data)
+def local_search(n, m, data):
+    first_solution = get_first_random_solution(n, m, data)
     best_solution = explore_neighbourhood(
         element=first_solution, data=data, max_iterations=100
     )
diff --git a/src/main.py b/src/main.py
index cf7f9f4..0f91cfd 100755
--- a/src/main.py
+++ b/src/main.py
@@ -10,7 +10,7 @@ def execute_algorithm(choice, n, m, data):
     if choice == "greedy":
         return greedy_algorithm(n, m, data)
     elif choice == "local":
-        return local_search(m, data)
+        return local_search(n, m, data)
     else:
         print("The valid algorithm choices are 'greedy' and 'local'")
         exit(1)