import os
import re
import time
import httpx
import structlog
import pandas as pd
from openai import OpenAI
from tools import tool_map
from dotenv import load_dotenv
from prompt import system_prompt
from eval_tool import eval_factor

# Load variables from a local .env file into the process environment; the LLM
# class below reads OPENAI_API_KEY and OPENAI_API_BASE_URL from os.environ.
load_dotenv()

# Module-level structlog logger used by the agent loop below.
logger = structlog.get_logger()

class LLM:
    """Stateful chat wrapper around an OpenAI-compatible chat-completions endpoint.

    The full conversation is kept in ``self.message`` and re-sent on every
    call, so each call continues the same dialogue.
    """

    def __init__(self, system_prompt: str, model: str = "deepseek-chat"):
        """
        :params system_prompt: system prompt placed first in the conversation
        :params model: model name sent with every request (defaults to the
            previously hard-coded "deepseek-chat")
        """
        # Only the bound ``create`` method is kept; credentials and endpoint
        # come from the environment.
        self.client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
            base_url=os.environ.get("OPENAI_API_BASE_URL"),
        ).chat.completions.create
        self.model = model

        self.message = [
            {"role": "system", "content": system_prompt}
        ]

    def __call__(self, user_prompt: str):
        """
        Send one user turn and return the assistant's reply text.

        :params user_prompt: user prompt for this turn
        :return: content of the first choice returned by the endpoint
        :raises Exception: whatever the client raises is propagated unchanged;
            the conversation history is restored first.
        """
        self.message.append(
            {"role": "user", "content": user_prompt}
        )

        try:
            completion = self.client(
                model=self.model,
                n=1,
                messages=self.message,
            )
            response = completion.choices[0].message.content
        except Exception:
            # Roll back the unanswered user turn so a retry does not send the
            # same prompt twice in the history.
            self.message.pop()
            raise

        self.message.append(
            {"role": "assistant", "content": response}
        )

        return response

class factor_agent:
    """Agent loop that lets the LLM drive data tools to build and evaluate factors.

    Every iteration sends an observation to the model, parses its reply into a
    dict with ``Thought`` / ``Action`` / ``Action_input``, runs the requested
    tool and feeds a textual description of the result back to the model.

    State kept on the instance:
      * ``observation_table`` - {table name: description of its columns}
      * ``table_map``         - {table name: table object returned by a tool}
      * ``factor_info``       - {factor name: evaluation text}. NOTE: when the
        loop stops because ``max_iteration`` is reached this attribute is
        replaced by a DataFrame; when the model finishes early it stays a dict.
    """

    # Wording used in observations for each comparison accepted by data_filter.
    _DIRECTION_WORDS = {'>': '大于', '=': '等于', '<': '小于'}
    # Tools grouped by call signature.
    _ROLLING_UNARY_OPS = ('m_avg', 'm_std', 'm_max', 'm_min', 'm_pct_rank')
    _CROSS_SECTION_OPS = ('c_avg', 'c_std', 'c_z_normalize', 'c_pct_rank', 'abs_func')
    _ROLLING_BINARY_OPS = ('m_corr',)
    _BINARY_OPS = ('add', 'divide', 'times', 'sub')

    # Metric columns (in output order) and the label preceding each value in
    # the evaluation text.
    _METRIC_LABELS = (
        ('IC', 'IC 为'),
        ('cum_return', '因子累计收益为'),
        ('yearly_return', '年化收益为'),
        ('sharp', '夏普比例为'),
        ('volatility', '年化波动率为'),
        ('maxdrawback', '最大回撤为'),
    )
    # Signed number with optional fraction / exponent, so integer-formatted
    # metrics parse as well.
    _NUMBER = r'([-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)'

    def __init__(self, max_iteration=10):
        """
        :params max_iteration: maximum number of loop iterations
        """
        self.max_iteration = max_iteration
        self.model = LLM(system_prompt=system_prompt)

        self.observation_table = {}  # {table name: table description}
        self.table_map = {}          # table name -> table object
        self.factor_info = {}        # {factor name: evaluation text}

    def _factor_to_dateframe(self):
        """
        Parse the evaluation texts in ``self.factor_info`` into a DataFrame.

        :return: DataFrame with columns name, IC, cum_return, yearly_return,
            sharp, volatility, maxdrawback (one row per factor). A metric that
            cannot be found in the text becomes NaN instead of aborting the
            whole conversion; a missing name falls back to the dict key.
        """
        rows = []
        for key, text in self.factor_info.items():
            name_match = re.search(r'因子\s+(.*?)\s+的', text)
            row = [name_match.group(1) if name_match else key]
            for _, label in self._METRIC_LABELS:
                value_match = re.search(label + r'\s+' + self._NUMBER, text)
                row.append(float(value_match.group(1)) if value_match else float('nan'))
            rows.append(row)
        columns = ['name'] + [column for column, _ in self._METRIC_LABELS]
        return pd.DataFrame(rows, columns=columns)

    @staticmethod
    def _parse_response(response):
        """
        Turn the model's reply text into a dict without executing it.

        SECURITY: the original code ran ``eval`` on the raw model output, which
        executes arbitrary expressions. The reply is now parsed as JSON, falling
        back to ``ast.literal_eval`` for Python-style dict literals. Text
        outside the outermost braces (e.g. a ```json fence) is ignored.

        :raises SyntaxError, ValueError: the reply is not a literal
        :raises TypeError: the reply is a literal but not a dict
        """
        import ast
        import json

        text = response.strip()
        start, end = text.find('{'), text.rfind('}')
        if start != -1 and end > start:
            text = text[start:end + 1]
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = ast.literal_eval(text)
        if not isinstance(parsed, dict):
            raise TypeError(f"expected a dict from the model, got {type(parsed).__name__}")
        return parsed

    def _store_table(self, table_name, data):
        """Register ``data`` under ``table_name`` and return its column list as text."""
        self.table_map[table_name] = data
        columns = ', '.join(list(data.columns))
        self.observation_table[table_name] = f'该表包含字段 {columns}'
        return columns

    def _run_action(self, action_name, params, table_name):
        """
        Execute one tool call and return the observation text for the model.

        :return: observation string, or None when ``action_name`` is not a
            known tool (nothing is executed in that case).
        """
        if action_name == 'eval_factor':
            # Evaluator: produces a text report, no new table.
            feature_name = params['feature_name']
            fac_info = eval_factor(params['start_date'], params['end_date'], feature_name)
            self.factor_info[feature_name] = fac_info
            return f'已使用 {action_name} 作因子分析: {fac_info}'

        try:
            func = tool_map[action_name]
        except KeyError:
            return None

        if action_name == 'get_data':
            data = func(params['feature_name'], params['start_date'], params['end_date'])
            columns = self._store_table(table_name, data)
            return f"已使用 get_data 抽取表格 {table_name} 数据, 表格包含字段 {columns}"

        if action_name == 'merge_data':
            name_1 = params['data_1']
            name_2 = params['data_2']
            data = func(self.table_map[name_1], self.table_map[name_2], params['how'])
            columns = self._store_table(table_name, data)
            return f"已使用 merge_data 将表格 {name_1} 与 {name_2} 拼接成一张新表 {table_name}, 表格包含字段 {columns}"

        if action_name == 'data_filter':
            name = params['data']
            threshold = params['threshold']
            feature_name = params['feature_name']
            direction = params['direction']
            data = func(self.table_map[name], feature_name, direction, threshold)
            columns = self._store_table(table_name, data)
            # Unknown comparison strings are echoed as-is instead of leaving
            # the wording undefined.
            symbol = self._DIRECTION_WORDS.get(direction, direction)
            return f"已使用 data_filter 将表格 {name} 中的字段 {feature_name} 作了 {symbol} {threshold} 的筛选并保存成一张新表 {table_name}, 表格包含字段 {columns}"

        if action_name in self._ROLLING_UNARY_OPS:
            # Time-series operator on one field with window N.
            name = params['data']
            feature_name = params['feature_name']
            N = params['N']
            data = func(self.table_map[name], feature_name, N)
            columns = self._store_table(table_name, data)
            return f'已使用 {action_name} 将表格 {name} 中的字段 {feature_name} 作了 {N} 日时序运算, 并保存到了表 {table_name} 中, 该表包含字段 {columns}'

        if action_name in self._CROSS_SECTION_OPS:
            # Cross-sectional (or absolute value) operator on one field.
            name = params['data']
            feature_name = params['feature_name']
            data = func(self.table_map[name], feature_name)
            columns = self._store_table(table_name, data)
            return f'已使用 {action_name} 将表格 {name} 中的字段 {feature_name} 作了截面(或绝对值)运算, 并保存到了表 {table_name} 中, 该表包含字段 {columns}'

        if action_name in self._ROLLING_BINARY_OPS:
            # Time-series operator on two fields with window N.
            name = params['data']
            feature_name_1 = params['feature_name_1']
            feature_name_2 = params['feature_name_2']
            N = params['N']
            data = func(self.table_map[name], feature_name_1, feature_name_2, N)
            columns = self._store_table(table_name, data)
            return f'已使用 {action_name} 将表格 {name} 中的字段 {feature_name_1} 和 {feature_name_2} 作了 {N} 日时序运算, 并保存到了表 {table_name} 中, 该表包含字段 {columns}'

        if action_name in self._BINARY_OPS:
            # Element-wise operator on two fields.
            name = params['data']
            feature_name_1 = params['feature_name_1']
            feature_name_2 = params['feature_name_2']
            data = func(self.table_map[name], feature_name_1, feature_name_2)
            columns = self._store_table(table_name, data)
            return f'已使用 {action_name} 将表格 {name} 中的字段 {feature_name_1} 和 {feature_name_2} 作了运算, 并保存到了表 {table_name} 中, 该表包含字段 {columns}'

        # Present in tool_map but with no known calling convention here.
        return None

    def __call__(self, user_prompt):
        """
        Run the mining loop starting from ``user_prompt``.

        The loop ends when the model replies without an ``Action`` or when
        ``max_iteration`` iterations have run (``factor_info`` is then
        converted to a DataFrame). Results are left on the instance.

        :params user_prompt: initial task description sent to the model
        :return: None
        """
        c = 0
        while True:
            c += 1
            raw_response = self.model(user_prompt)
            print(raw_response)

            response = self._parse_response(raw_response)

            thought = response['Thought']
            logger.info(f"Thought : {thought}")

            # A reply without an action means the model considers the task done.
            if 'Action' not in response:
                return

            action_name = response['Action']
            params = response['Action_input']

            # Computed fresh every iteration so a failed dispatch can never
            # reuse the previous iteration's observation.
            observation = self._run_action(action_name, params, f'table_{c}')
            if observation is None:
                logger.warning(f"Unknown action : {action_name}")
                observation = f'未知的 Action: {action_name}, 未执行任何操作'

            logger.info(f"Observation : {c}. {observation}")
            t = str(self.observation_table)
            t2 = str(self.factor_info)
            observation += f'\n 从开始到现在已有的表格为: {t} \n 所有的因子绩效情况为: {t2}'
            observation += '\n 请继续进行挖掘任务'
            user_prompt = observation

            time.sleep(0.5)

            if self.max_iteration <= c:
                self.factor_info = self._factor_to_dateframe()
                return


# Illustrative example of the chat-message list format (system / user /
# assistant roles) that LLM keeps in ``self.message``. This is a bare
# expression that is not assigned to anything, so it has no runtime effect.
[
    {"role": "system", "content": "你是一个精通量化投资的量化分析师"}, 
    {"role": "user", "content": "如何学习量化投资相关知识"}, 
    {"role": "assistant", "content": "你应该**********"}
]
