import logging
import os
import random

from neo4j import GraphDatabase

# ==========================
# Neo4j connection settings
# ==========================

# NEO4J_URI = "bolt://172.164.240.161:7687"
# Credentials may be overridden with the NEO4J_USER / NEO4J_PASSWORD
# environment variables; the literals below are only fallback values so that
# existing importers of these constants keep working unchanged.
# SECURITY: the password fallback is a credential committed to source control.
# TODO: rotate it and drop the default once deployments set NEO4J_PASSWORD.
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "@MahilaMoghadami@")

# Configure the root logger at import time: INFO and above, "LEVEL: message".
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")


# ==========================
# Neo4j Explorer Class
# ==========================

class FoodMCExplorer:
    """Query helper for ``Food`` nodes grouped by their ``AssignedMC`` property.

    The instance owns a Neo4j driver (``self.driver``) created in ``__init__``.
    Release it either by calling :meth:`close` or by using the instance as a
    context manager::

        with FoodMCExplorer(uri, user, password) as explorer:
            mc = explorer.get_assigned_mc(food_code)
    """

    def __init__(self, uri, user, password):
        """Create the driver for ``uri`` using basic ``(user, password)`` auth."""
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def __enter__(self):
        """Return the explorer itself so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the driver on exit; exceptions from the body still propagate."""
        self.close()
        return False

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

    # --------------------------------------------------
    # Step 1: Get AssignedMC of input FoodCode
    # --------------------------------------------------
    def get_assigned_mc(self, food_code):
        """Return the ``AssignedMC`` property of the Food with ``food_code``.

        Args:
            food_code: Value matched against the ``FoodCode`` node property.

        Returns:
            The ``AssignedMC`` value stored on the matched node.

        Raises:
            ValueError: If no record comes back or ``AssignedMC`` is null.
        """
        with self.driver.session() as session:
            result = session.run("""
                MATCH (f:Food {FoodCode: $food_code})
                RETURN f.AssignedMC AS assigned_mc
            """, food_code=food_code)

            # NOTE(review): single() expects at most one matching record; how
            # duplicates of FoodCode are reported depends on the driver
            # version - confirm FoodCode is unique in the graph.
            record = result.single()
            if not record or record["assigned_mc"] is None:
                raise ValueError("AssignedMC not found for given FoodCode")

            return record["assigned_mc"]

    # --------------------------------------------------
    # Step 2: Get all foods with same AssignedMC
    # --------------------------------------------------
    def get_foods_by_assigned_mc(self, assigned_mc):
        """Return every Food whose ``AssignedMC`` equals ``assigned_mc``.

        Args:
            assigned_mc: Value matched against the ``AssignedMC`` node property.

        Returns:
            A list with one plain ``dict`` per matched node, keyed by the
            aliases in the RETURN clause (``FoodCode``, ``FoodNamePersian``,
            ``Healthy_level`` and the ``Disease_*_Level`` columns). The list
            is empty when nothing matches.
        """
        with self.driver.session() as session:
            result = session.run("""
                MATCH (f:Food {AssignedMC: $assigned_mc})
                RETURN
                    f.FoodCode AS FoodCode,
                    f.FoodNamePersian AS FoodNamePersian,
                    f.Healthy_level AS Healthy_level,
                    f.Disease_Diabetes_Level AS Disease_Diabetes_Level,
                    f.Disease_Hypertension_Level AS Disease_Hypertension_Level,
                    f.Disease_CVD_Level AS Disease_CVD_Level,
                    f.Disease_Celiac_Level AS Disease_Celiac_Level,
                    f.Disease_IBS_Level AS Disease_IBS_Level,
                    f.Disease_Anemia_Level AS Disease_Anemia_Level,
                    f.Disease_MS_Level AS Disease_MS_Level,
                    f.Disease_Gout_Level AS Disease_Gout_Level,
                    f.Disease_HighCholesterol_Level AS Disease_HighCholesterol_Level,
                    f.Disease_Kidney_Level AS Disease_Kidney_Level,
                    f.Disease_IBD_Level AS Disease_IBD_Level,
                    f.Disease_LactoseIntolerance_Level AS Disease_LactoseIntolerance_Level,
                    f.Disease_NAFLD_Level AS Disease_NAFLD_Level,
                    f.Disease_Hypothyroidism_Level AS Disease_Hypothyroidism_Level,
                    f.Disease_PCOS_Level AS Disease_PCOS_Level
            """, assigned_mc=assigned_mc)

            # Materialise inside the session: the result is consumed here,
            # before the ``with`` block closes the session.
            return [dict(record) for record in result]


# ==========================
# FINAL_TAG logic
# ==========================

def compute_final_tag(food, disease_list):
    """
    Return the lowest level among Healthy_level and the selected diseases.

    For each name in ``disease_list`` the key ``Disease_<name>_Level`` is
    looked up in ``food``. Missing keys and ``None`` values are ignored.
    Returns ``None`` when no level is available at all.
    """
    # Healthy_level always takes part; disease keys follow in caller order.
    lookup_keys = ["Healthy_level"]
    lookup_keys.extend(f"Disease_{name}_Level" for name in disease_list)

    levels = [food.get(field) for field in lookup_keys]
    usable = [level for level in levels if level is not None]

    # ``default`` covers the "nothing to compare" case without a branch.
    return min(usable, default=None)


def split_and_shuffle_foods(foods, disease_list):
    """
    Tag each food with FINAL_TAG, then bucket and shuffle by tag.

    Every dict in ``foods`` is mutated in place: its ``"FINAL_TAG"`` key is
    set to the result of ``compute_final_tag``. Foods tagged 2 and foods
    tagged 1 are collected into separate lists, each shuffled independently;
    foods with any other tag (including ``None``) are tagged but left out of
    all returned lists.

    Returns a tuple ``(tag_2_foods, tag_1_foods, combined)`` where
    ``combined`` is the tag-2 list followed by the tag-1 list.
    """
    tier_two = []
    tier_one = []

    for item in foods:
        tag = compute_final_tag(item, disease_list)
        item["FINAL_TAG"] = tag

        if tag == 2:
            bucket = tier_two
        elif tag == 1:
            bucket = tier_one
        else:
            # Not eligible for either tier.
            continue
        bucket.append(item)

    # Shuffle the tag-2 tier first, then the tag-1 tier, so the sequence of
    # random draws is fixed for a given RNG state.
    for bucket in (tier_two, tier_one):
        random.shuffle(bucket)

    combined = [*tier_two, *tier_one]

    summary = (
        ("TWO_FINALTAG", tier_two),
        ("ONE_FINALTAG", tier_one),
        ("FINAL_LIST", combined),
    )
    for label, bucket in summary:
        logging.info(f"{label} count: {len(bucket)}")

    return tier_two, tier_one, combined


# ==========================
# Main Execution
# ==========================

# if __name__ == "__main__":

#     # -------- Inputs --------
#     input_food_code = 90024

#     disease_list = [
#         "Diabetes",
#         # "Hypertension",
#         # "CVD",
#         # "IBS"
#     ]

#     explorer = FoodMCExplorer(
#         NEO4J_URI,
#         NEO4J_USER,
#         NEO4J_PASSWORD
#     )

#     try:
#         logging.info("Fetching AssignedMC...")
#         assigned_mc = explorer.get_assigned_mc(input_food_code)
#         logging.info(f"AssignedMC: {assigned_mc}")

#         logging.info("Fetching foods with same AssignedMC...")
#         foods = explorer.get_foods_by_assigned_mc(assigned_mc)
#         logging.info(f"Total foods found: {len(foods)}")

#         two_final, one_final, final_list = split_and_shuffle_foods(
#             foods,
#             disease_list
#         )

#         print("\n=== TWO_FINALTAG ===")
#         for f in two_final[:10]:
#             print(f"{f['FoodNamePersian']} | FINAL_TAG: {f['FINAL_TAG']} | FoodCode: {f['FoodCode']}")

#         print("\n=== ONE_FINALTAG ===")
#         for f in one_final[:10]:
#             print(f"{f['FoodNamePersian']} | FINAL_TAG: {f['FINAL_TAG']} | FoodCode: {f['FoodCode']}")


#         print("\n=== FINAL LIST (first 50) ===")
#         for f in final_list[:50]:
#             print(f"{f['FoodNamePersian']} | FINAL_TAG: {f['FINAL_TAG']} | FoodCode: {f['FoodCode']}")

#     finally:
#         explorer.close()
