from collections import defaultdict
import argparse
import copy
import json
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import font_manager
from scipy.optimize import curve_fit


def find_max_indices_numpy(L_dict):
    """Return all keys of L_dict whose value equals the maximum, plus that maximum."""
    keys_arr = np.array(list(L_dict.keys()))
    values_arr = np.array(list(L_dict.values()))
    max_val_from_dict = max(L_dict.values())  # or np.max(values_arr)

    indices = np.where(values_arr == max_val_from_dict)[0]
    max_keys_np_way = keys_arr[indices]
    return max_keys_np_way, max_val_from_dict


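# A quick sanity check (hypothetical toy input):
#   keys, best = find_max_indices_numpy({'u1': 3, 'u2': 5, 'u3': 5})
#   # keys -> np.array(['u2', 'u3']), best -> 5

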
def R(event_dict):
    """Reward of a candidate: the negated mean activation time over its events.

    Earlier average activation (smaller times) yields a larger reward.
    """
    mean = 0
    for _, v in event_dict.items():
        mean += -v
    mean /= len(event_dict)
    return mean


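# Worked example (hypothetical normalized activation times):
#   R({'e1': 0.2, 'e2': 0.4}) == -(0.2 + 0.4) / 2 == -0.3

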
def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b):
    """
    Parameters:
    - R: reward function; R(event_dict) returns a numeric value
    - L_dict: dict of node -> integer (number of events the node appears in)
    - user_event_dict: user -> event -> activation time
    - event_user_dict: event -> user list
    - b: budget (numeric)

    Returns:
    - A: selected set of social sensors
    - all_cas: number of cascade participations covered since the last reset

    Note: relies on the module-level `args` for the `anchor_screen` threshold.
    """
    print("Select social sensors begin!")
    user_event_dict_ = copy.deepcopy(user_event_dict)
    L_dict_ = copy.deepcopy(L_dict)

    V = list(L_dict_.keys())
    A = []
    f = lambda A: len(A)
    all_cas = 0
    while any(s not in A for s in V) and f(A) < b:

        indices_np, max_val_from_dict = find_max_indices_numpy(L_dict_)
        if max_val_from_dict == 0:
            print("No node left that participates in any cascade!")
            break
        TAR_set = set(indices_np)

        # Lazy-greedy (CELF-style) evaluation over the highest-coverage candidates.
        delta = {}
        cur = {}
        for s in TAR_set:
            delta[s] = float('inf')
            cur[s] = False

        c_star = []
        while True:
            s_star = max(delta, key=delta.get)
            c_star = user_event_dict_[s_star]
            if cur[s_star]:
                if delta[s_star] >= args.anchor_screen:
                    A.append(s_star)
                break
            else:
                delta[s_star] = R(c_star)
                cur[s_star] = True

        # Candidates whose reward falls below the screening threshold are dropped.
        if delta[s_star] < args.anchor_screen:
            L_dict_[s_star] = 0
            continue

        all_cas += len(c_star)
        # Remove the covered cascades so later picks are rewarded only for new coverage.
        for cas_id in list(c_star.keys()):
            uc = event_user_dict[cas_id]
            for v in uc:
                if v in L_dict_.keys():
                    L_dict_[v] -= 1
                _ = user_event_dict_[v].pop(cas_id)
        print(f"Add a social sensor, sensors num is {len(A)}")
        print(f"Anchor id: {s_star}")
        print(f"all_cas is {all_cas}")
        print(f"TAR_set size: {len(TAR_set)}")
        if all_cas >= len(event_user_dict):
            # Every cascade is covered: restore the working copies (zeroing out
            # already-selected sensors) and keep selecting extra nodes.
            user_event_dict_ = copy.deepcopy(user_event_dict)
            L_dict_ = copy.deepcopy(L_dict)
            for s in A:
                L_dict_[s] = 0
            all_cas = 0
            print("All cascades are covered; restarting coverage to add extra nodes")
            # break

    print(f"Select social sensors finish! Get social sensors, num: {len(A)}")
    return A, all_cas


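# A minimal end-to-end sketch on toy data (all names below are hypothetical;
# the module-level `args` must exist because select_social_sensors reads
# args.anchor_screen):
#
#   args = argparse.Namespace(anchor_screen=-0.8)
#   user_event = {'u1': {'e1': 0.1, 'e2': 0.2}, 'u2': {'e1': 0.9}}
#   event_user = {'e1': ['u1', 'u2'], 'e2': ['u1']}
#   L = {u: len(ev) for u, ev in user_event.items()}
#   A, _ = select_social_sensors(R, L, user_event, event_user, b=1)
#   # A == ['u1']: it covers both cascades and activates earliest on average.

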
def handle_cas(args):
    filename, result_path, obs_time = args.input_file, args.result_path, args.obs_time
    print("Handle cascade begin!")
    user_event_dict = defaultdict(dict)
    event_user_dict = defaultdict(list)
    cascades_total = 0

    # t_min_cas = -1
    t_min_all = args.t_min

    with open(filename) as file:
        for line in file:

            cascades_total += 1
            if cascades_total > args.obs_cas_num and args.obs_cas_num > 0:
                break
            parts = line.split(',')
            cascade_id = parts[0]
            activation_times = {}

            paths = parts[1:]

            t_max = 0
            t_min = float('inf')
            for p in paths:
                # observed adoption/participant
                n_t = p.split(':')
                if len(n_t) < 2:
                    continue
                nodes = n_t[0].split('/')
                time_now = int(n_t[1])
                if time_now > t_max:
                    t_max = time_now
                if time_now < t_min:
                    t_min = time_now
                node = nodes[-1]
                node_id = node

                if time_now > obs_time and obs_time != -1:
                    continue

                # Keep the earliest activation time per node.
                if node_id in activation_times.keys():
                    activation_times[node_id] = min(time_now, activation_times[node_id])
                else:
                    activation_times[node_id] = time_now

            print(f"cascade_id:{cascade_id}, t_min:{t_min}")

            if len(activation_times) < args.cas_min_length:
                print(f"Too short cascade! Cas_id:{cascade_id}")
                continue
            if t_min > args.t_max or t_min < args.t_min:
                print(f"Cascade is not in the proper time window! Cas_id:{cascade_id}")
                continue

            # Normalize activation times within the cascade to [0, 1].
            for k, v in activation_times.items():
                event_user_dict[cascade_id].append(k)
                if t_max > t_min:
                    user_event_dict[k][cascade_id] = (v - t_min) / (t_max - t_min)

    # print(f"t_min_cas: {t_min_cas}")
    L_dict = {}

    for k, v in user_event_dict.items():
        L_dict[k] = len(v.keys())

    print(f"Handle cascade file finish! Users num is {len(L_dict)}")
    os.makedirs(result_path, exist_ok=True)
    json_file = os.path.join(result_path, f'user_event_{args.obs_cas_num}.json')
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(user_event_dict, f, indent=4, ensure_ascii=False)
        print(f"user_event_dict saved to JSON file: {json_file}")
    except TypeError as e:
        print(f"Error: the dict may contain types that cannot be serialized to JSON. Error: {e}")
    except IOError:
        print(f"Error: cannot open or write file '{json_file}'.")
    except Exception as e:
        print(f"Unknown error while saving the JSON file: {e}")

    return L_dict, user_event_dict, event_user_dict


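# Expected input format (inferred from the parsing above; one cascade per line):
#   cascade_id,node_a:t1,node_a/node_b:t2,node_a/node_c:t3,...
# Each comma-separated path ends in the newly activated node, and the integer
# after ':' is that node's activation timestamp.

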
def generate_anchors(args):
    input_file, result_path, anchor_budget, obs_time = args.input_file, args.result_path, args.anchor_budget, args.obs_time

    L_dict, user_event_dict, event_user_dict = handle_cas(args)

    num_nodes = len(L_dict.keys())

    anchor_num = 0

    # Cap the budget at 2% of all nodes.
    if anchor_budget > 0.02:
        max_anchor_num = int(num_nodes * 0.02)
        print(f"Max anchor num is {max_anchor_num}, anchor_budget is set to {max_anchor_num}")
        anchor_num = max_anchor_num
    else:
        anchor_num = int(num_nodes * anchor_budget)

    if args.anchor_num != -1:
        anchor_num = args.anchor_num
    print(f"anchor_num is set to {anchor_num}")
    # anchor_num = 10

    A, all_cas = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_num)
    os.makedirs(result_path, exist_ok=True)
    output_file = os.path.join(result_path, f'anchors_{args.anchor_budget}.txt')
    with open(output_file, 'w') as file:
        for item in A:
            file.write(f"{item}\n")

    return A, all_cas, user_event_dict


# Logistic function definition
def logistic(t, K, r, t0):
    return K / (1 + np.exp(-r * (t - t0)))


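# Shape note: logistic(t0, K, r, t0) == K / 2 for any r, so t0 marks the
# inflection point of the fitted S-curve, i.e. the estimated outbreak time.

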
def find_inflection_logistic(t_list, c_list):
    t_array = np.array(t_list)
    c_array = np.array(c_list)

    # Normalize time to [0, 1].
    t_min = t_array.min()
    t_max = t_array.max()

    if len(set(c_array)) < 3 or max(c_array) <= 1 or t_max - t_min == 0:
        return None
    t_array_normalized = (t_array - t_min) / (t_max - t_min)

    try:
        p0 = [max(c_array) if max(c_array) > 0 else 1.0, 1.0, 0.5]  # keep the K guess positive even if c_array were all zeros

        # Sanity-check c_array.
        if max(c_array) <= 0:  # an S-curve cannot be fitted to all-zero or negative cumulative counts
            print("Warning: the cumulative counts (c_list) show no positive growth; cannot fit an S-curve.")
            return None

        # Bounds: K > 0 and r > 0 (growth rate), so their lower bounds are a
        # small positive number; t0 must lie in the normalized range [0, 1].
        bounds = ([1e-6, 1e-6, 0], [np.inf, 10, 1])  # K, r, t0

        popt, _ = curve_fit(logistic, t_array_normalized, c_array, p0=p0, bounds=bounds, maxfev=10000)
        K, r, t0 = popt
        return t0
    except (RuntimeError, ValueError):
        print("Warning: curve fitting failed or produced invalid parameters")
        return None


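# A quick self-check on synthetic data (constants chosen arbitrarily): the
# fit should recover a t0 close to the true inflection point.
#
#   t = np.linspace(0, 1, 50)
#   c = logistic(t, K=100, r=8, t0=0.5)
#   find_inflection_logistic(t.tolist(), c.tolist())  # ~0.5

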
# Estimate the outbreak time of every event
def get_t0(args):
    input_file, result_path = args.input_file, args.result_path

    print("To get All Cas' T0")
    os.makedirs(result_path, exist_ok=True)
    json_file = os.path.join(result_path, 'prepare.json')
    prepare = {}
    if os.path.exists(json_file):
        print(f"File {json_file} exists.")
        with open(json_file, 'r', encoding='utf-8') as f:
            prepare = json.load(f)
        return prepare
    t0_dict = {}
    num_time_points = 50  # number of sample points on the average-trend timeline; adjustable
    common_normalized_timeline = np.linspace(0, 1, num_time_points)
    nodes_in_bins = {i: 0 for i in range(len(common_normalized_timeline) - 1)}

    interpolated_counts_all_cascades = []
    cascades_total = 0
    t0_mean = 0
    with open(input_file) as file:
        for line in file:
            cascades_total += 1
            if cascades_total > args.obs_cas_num and args.obs_cas_num > 0:
                break
            parts = line.split(',')
            cascade_id = parts[0]

            paths = parts[1:]
            t_max = 0
            t_min = float('inf')
            times = []
            node_ids = []
            for p in paths:
                n_t = p.split(':')
                if len(n_t) < 2:  # skip malformed entries, as in handle_cas
                    continue
                nodes = n_t[0].split('/')
                time_now = int(n_t[1])
                if time_now > t_max:
                    t_max = time_now
                if time_now < t_min:
                    t_min = time_now
                node = nodes[-1]
                node_id = node
                times.append(time_now)
                node_ids.append(node_id)

            if t_max - t_min <= 0:
                continue
            # Rebuild the cumulative unique-node counts in chronological order:
            # the file order of paths is not guaranteed to be sorted, and
            # np.interp below needs an increasing x-axis aligned with counts.
            order = np.argsort(times)
            sorted_times = [times[i] for i in order]
            seen = set()
            counts = []
            for i in order:
                seen.add(node_ids[i])
                counts.append(len(seen))

            nodes_in_bin = {i: 0 for i in range(len(common_normalized_timeline) - 1)}
            for i in range(len(times)):
                participation_time = (sorted_times[i] - t_min) / (t_max - t_min)
                sorted_times[i] = participation_time
                bin_index = np.digitize(participation_time, common_normalized_timeline, right=False) - 1
                if bin_index == -1 and participation_time == common_normalized_timeline[0]:
                    bin_index = 0
                elif bin_index == len(common_normalized_timeline) - 1:
                    if participation_time == common_normalized_timeline[-1]:
                        bin_index = len(common_normalized_timeline) - 2
                    else:
                        continue
                if 0 <= bin_index < len(nodes_in_bin):
                    nodes_in_bins[bin_index] += 1
                    nodes_in_bin[bin_index] += 1

            interpolated_counts = np.interp(common_normalized_timeline,
                                            sorted_times,
                                            counts)
            estimated_increments_between_common_times = np.diff(interpolated_counts)
            interpolated_counts_all_cascades.append(estimated_increments_between_common_times)

            t0 = find_inflection_logistic(sorted_times, counts)
            if t0 is None:
                continue
            t0_dict[cascade_id] = [t0, (t_max - t_min) / 3600]
            t0_mean += t0

    t0_mean /= max(len(t0_dict), 1)  # guard against no successful fits
    stacked_interpolated_counts = np.vstack(interpolated_counts_all_cascades).tolist()
    counts_0 = np.array([nodes_in_bins[i] for i in sorted(nodes_in_bins.keys())])
    all_counts = (counts_0 / max(counts_0.max(), 1)).tolist()  # avoid division by zero

    prepare['t0'] = t0_dict
    prepare['t0_mean'] = t0_mean
    prepare['interpolated_counts_sum'] = stacked_interpolated_counts
    prepare['interpolated_counts_step_avg'] = all_counts

    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(prepare, f, indent=4, ensure_ascii=False)
        print(f"t0_dict saved to JSON file: {json_file}")
    except TypeError as e:
        print(f"Error: the dict may contain types that cannot be serialized to JSON. Error: {e}")
    except IOError:
        print(f"Error: cannot open or write file '{json_file}'.")
    except Exception as e:
        print(f"Unknown error while saving the JSON file: {e}")

    return prepare


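# prepare.json layout (as written above; the 0.42 is illustrative):
#   {
#     "t0": {"<cascade_id>": [t0_normalized, duration_hours], ...},
#     "t0_mean": 0.42,
#     "interpolated_counts_sum": [[...per-cascade increments...], ...],
#     "interpolated_counts_step_avg": [...normalized bin counts...]
#   }

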
def analyse_result(node_list, user_event_dict, prepare, args, set_name):
    result_path = args.result_path
    t0_dict = prepare['t0']
    t0_mean = prepare['t0_mean']
    # Trailing "_" marks normalized times; the plain names are scaled by each
    # cascade's duration in hours.
    anchors_activation_times_before_t0_ = defaultdict(list)
    anchors_activation_times_before_t0 = defaultdict(list)
    anchors_activation_times_ = defaultdict(list)
    anchors_activation_times = defaultdict(list)
    cas_set = set()
    for k, v in user_event_dict.items():
        if k in node_list:
            for cas_id, ac_time in v.items():
                cas_set.add(cas_id)
                if cas_id not in t0_dict:
                    continue
                anchors_activation_times_before_t0_[k].append(ac_time - t0_dict[cas_id][0])
                anchors_activation_times_before_t0[k].append((ac_time - t0_dict[cas_id][0]) * t0_dict[cas_id][1])
                anchors_activation_times_[k].append(ac_time)
                anchors_activation_times[k].append(ac_time * t0_dict[cas_id][1])
    print(f"Number of perceived cascades: {len(cas_set)}")

    anchors_avg_activation_time_before_t0 = dict()
    anchors_avg_activation_time = dict()

    for node_id, times in anchors_activation_times_before_t0.items():
        if times:
            anchors_avg_activation_time_before_t0[node_id] = (sum(times) / len(times))
        else:
            anchors_avg_activation_time_before_t0[node_id] = None

    for node_id, times in anchors_activation_times.items():
        if times:
            anchors_avg_activation_time[node_id] = (sum(times) / len(times))
        else:
            anchors_avg_activation_time[node_id] = None

    anchors_avg_activation_time_before_t0_ = dict()
    anchors_avg_activation_time_ = dict()

    for node_id, times in anchors_activation_times_before_t0_.items():
        if times:
            anchors_avg_activation_time_before_t0_[node_id] = (sum(times) / len(times))
        else:
            anchors_avg_activation_time_before_t0_[node_id] = None

    for node_id, times in anchors_activation_times_.items():
        if times:
            anchors_avg_activation_time_[node_id] = (sum(times) / len(times))
        else:
            anchors_avg_activation_time_[node_id] = None

    anchors_total_avg_time_before_t0 = sum(t for t in anchors_avg_activation_time_before_t0.values() if t is not None) / len(anchors_avg_activation_time_before_t0)
    anchors_total_avg_time = sum(t for t in anchors_avg_activation_time.values() if t is not None) / len(anchors_avg_activation_time)

    anchors_total_avg_time_before_t0_ = sum(t for t in anchors_avg_activation_time_before_t0_.values() if t is not None) / len(anchors_avg_activation_time_before_t0_)
    anchors_total_avg_time_ = sum(t for t in anchors_avg_activation_time_.values() if t is not None) / len(anchors_avg_activation_time_)
    result_path = os.path.join(result_path, f'{set_name}')
    os.makedirs(result_path, exist_ok=True)

    anchor_event_dict = {}
    for anchor in node_list:
        anchor_event_dict[anchor] = user_event_dict[anchor]

    anchor_event_file = os.path.join(result_path, f'anchor_event_{args.anchor_budget}.json')
    try:
        with open(anchor_event_file, 'w', encoding='utf-8') as f:
            json.dump(anchor_event_dict, f, indent=4, ensure_ascii=False)
        print(f"anchor_event_dict saved to JSON file: {anchor_event_file}")
    except TypeError as e:
        print(f"Error: the dict may contain types that cannot be serialized to JSON. Error: {e}")
    except IOError:
        print(f"Error: cannot open or write file '{anchor_event_file}'.")
    except Exception as e:
        print(f"Unknown error while saving the JSON file: {e}")

    analyze_result = {}
    analyze_result['anchors_act_cas_num'] = len(cas_set)
    analyze_result['act_cas'] = list(cas_set)
    analyze_result['anchors_total_avg_time_before_t0'] = anchors_total_avg_time_before_t0
    analyze_result['anchors_avg_activation_time_before_t0'] = anchors_avg_activation_time_before_t0
    analyze_result['anchors_total_avg_time_before_t0_'] = anchors_total_avg_time_before_t0_
    analyze_result['anchors_avg_activation_time_before_t0_'] = anchors_avg_activation_time_before_t0_
    analyze_result['anchors_total_avg_time'] = anchors_total_avg_time
    analyze_result['anchors_avg_activation'] = anchors_avg_activation_time
    analyze_result['anchors_total_avg_time_'] = anchors_total_avg_time_
    analyze_result['anchors_avg_activation_'] = anchors_avg_activation_time_

    json_file = os.path.join(result_path, f'analyze_result_{args.anchor_budget}.json')
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(analyze_result, f, indent=4, ensure_ascii=False)
        print(f"analyze_result saved to JSON file: {json_file}")
    except TypeError as e:
        print(f"Error: the dict may contain types that cannot be serialized to JSON. Error: {e}")
    except IOError:
        print(f"Error: cannot open or write file '{json_file}'.")
    except Exception as e:
        print(f"Unknown error while saving the JSON file: {e}")

    print("Plotting activation-time distribution")
    # --- Activation-time plot ---
    important_times = [t for t in anchors_avg_activation_time_before_t0_.values() if t is not None]
    plt.figure(figsize=(10, 6))
    sns.kdeplot(important_times, label='anchors', fill=True, color='red', linewidth=2)

    plt.title('Anchors activate time')
    plt.xlabel('Time')
    plt.ylabel('P')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    fig_file_1 = os.path.join(result_path, f'anchor_act_time_{args.anchor_budget}.png')
    plt.savefig(fig_file_1,
                dpi=300,              # resolution (dots per inch)
                bbox_inches='tight',  # trim surrounding whitespace
               )
    plt.close()

    print("Plotting trend chart")

    # --- Trend plot ---
    min_time = 0.0
    max_time = 1  # activation times are normalized, so the timeline ends at 1
    bin_width = 0.01
    num_bins = int(np.ceil((max_time - min_time) / bin_width))  # round the bin count up

    # Bin edges, e.g. [0.00, 0.01, 0.02, ..., 1.00]
    time_bin_edges = np.linspace(min_time, max_time, num_bins + 1)

    nodes_in_bins = {i: 0 for i in range(len(time_bin_edges) - 1)}
    for node, participation_time in anchors_avg_activation_time_.items():
        if participation_time is None:  # nodes without recorded activations cannot be binned
            continue
        bin_index = np.digitize(participation_time, time_bin_edges, right=False) - 1
        if bin_index == -1 and participation_time == time_bin_edges[0]:
            bin_index = 0
        elif bin_index == len(time_bin_edges) - 1:
            if participation_time == time_bin_edges[-1]:
                bin_index = len(time_bin_edges) - 2
            else:
                continue
        if 0 <= bin_index < len(nodes_in_bins):
            nodes_in_bins[bin_index] += 1

    plt.figure(figsize=(10, 6))

    bin_labels = []
    for i in range(len(time_bin_edges) - 1):
        start_time = time_bin_edges[i]
        end_time = time_bin_edges[i + 1]
        # Label each bin by its interval (alternatively, use its midpoint).
        bin_labels.append(f"[{start_time:.2f}-{end_time:.2f})")
        # bin_labels.append(f"{start_time:.1f}-")  # terser label: start only

    # Normalized count per bin
    counts_0 = np.array([nodes_in_bins[i] for i in sorted(nodes_in_bins.keys())])
    counts = (counts_0 / max(counts_0.max(), 1)).tolist()  # avoid division by zero
    # X positions (bin centers) for the bars
    x_positions = time_bin_edges[:-1] + bin_width / 2

    # Draw the per-bin counts (the original computed counts and x_positions
    # but never plotted them).
    plt.bar(x_positions, counts, width=bin_width)

    # --- Vertical red line at the average outbreak time ---
    plt.axvline(x=t0_mean,
                color='red',
                linestyle='--',
                linewidth=2,
                label=f'Average outbreak time = {t0_mean:.3f}')
    plt.legend()
    fig_file_2 = os.path.join(result_path, f'trend_{args.anchor_budget}.png')
    plt.savefig(fig_file_2,
                dpi=300,
                bbox_inches='tight',
               )
    plt.close()
    return


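# Output artifacts written above, under <result_path>/<set_name>/:
#   anchor_event_<budget>.json    per-anchor event -> activation-time map
#   analyze_result_<budget>.json  summary statistics
#   anchor_act_time_<budget>.png  KDE of average activation times before t0
#   trend_<budget>.png            binned activation trend with mean outbreak line

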
def parse_args():
    parser = argparse.ArgumentParser(description='Parameters')
    parser.add_argument('--input_file', default='./dataset_for_anchor.txt', type=str, help='Cascade file')
    parser.add_argument('--result_path', default='./result/', type=str, help='Result save path')
    parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchor budget as a fraction of all nodes (capped at 0.02)')
    parser.add_argument('--anchor_num', default=1000, type=int, help='Fixed number of anchors; -1 derives it from the budget')
    parser.add_argument('--obs_cas_num', default=-1, type=int, help='Number of cascades the anchors observe; -1 means all')
    parser.add_argument('--obs_time', default=-1, type=int, help='Observation time limit for anchors; the default -1 means everything is observable')
    parser.add_argument('--key_node_file', default='/home/zjy/project/CasDO/dataset/nh/key_node_10', type=str, help='key_node_json save file')
    parser.add_argument('--anchor_screen', default=-0.8, type=float, help='Reward threshold for screening anchors')
    parser.add_argument('--cas_min_length', default=20, type=int, help='Minimum cascade length')
    parser.add_argument('--t_max', default=1751299200, type=int, help='Latest allowed cascade start timestamp')
    parser.add_argument('--t_min', default=1751299200, type=int, help='Earliest allowed cascade start timestamp')
    return parser.parse_args()


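# Example invocation (the script filename is a placeholder):
#   python select_anchors.py --input_file ./dataset_for_anchor.txt \
#       --result_path ./result/ --anchor_budget 0.005 --anchor_num -1

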
if __name__ == '__main__':
    args = parse_args()

    # =========== Run the anchor-mining algorithm; returns the anchors and the number of perceived cascades ===========
    A, all_cas, user_event_dict = generate_anchors(args)

    # =========== Analyze the results ===========
    print("Start analyzing results")
    prepare = get_t0(args)
    set_name = 'anchor'
    analyse_result(A, user_event_dict, prepare, args, set_name)

    # # =========== Load important nodes mined by other algorithms ===========
    # with open(args.key_node_file, 'r') as file:
    #     # strip() removes each line's newline; keep only numeric lines,
    #     # skipping blanks and invalid entries
    #     key_node_list = [line.strip() for line in file if line.strip().isdigit()]

    # key_node_set = set(key_node_list[:len(A)])

    # json_file = os.path.join(args.result_path, 'user_event.json')
    # user_event_dict = {}
    # if os.path.exists(json_file):
    #     print(f"File {json_file} exists.")
    #     with open(json_file, 'r', encoding='utf-8') as f:
    #         user_event_dict = json.load(f)

    # # prepare = get_t0(args)

    # set_name_d = 'degree'
    # analyse_result(key_node_set, user_event_dict, prepare, args, set_name_d)