Compare commits

3 commits: fad840b624...ecb3c705c2

| Author | SHA1 | Date |
|---|---|---|
|  | ecb3c705c2 |  |
|  | bfcba4660a |  |
|  | 128ed7994e |  |

.gitignore (vendored; new file, 1 line)

```diff
@@ -0,0 +1 @@
+result/
```
(documentation file; name not captured in this view)

```diff
@@ -1,11 +1,8 @@
-Required Python packages: numpy
+Required Python packages: numpy, seaborn, matplotlib, scipy
 
 Running `python anchor.py` automatically saves the computed m-points locally; it takes four parameters:
-
 input_file: path of the cascade data file;
-
-output_file: path where the computed m-points are saved;
-
+result_path: path where the computed results are saved;
 anchor_budget: the m-point budget;
-
 obs_time: observable time, defaults to -1, may be omitted;
+max_cas_num: number of events to process, defaults to -1, may be omitted;
```
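Putting the documented parameters together with the `parse_args` defaults further down, a sample invocation might look like the following (all values shown are the defaults introduced by this patch; note the budget is now a fraction of the node count):

```
python anchor.py --input_file ./dataset_for_anchor.txt --result_path ./result/ --anchor_budget 0.005 --obs_time -1 --max_cas_num -1
```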
							
								
								
									
anchor.py (386 lines changed)

```diff
@@ -1,7 +1,16 @@
 from collections import defaultdict
+import time
 import numpy as np
 import copy
 import argparse
+import json
+from scipy.optimize import curve_fit
+import matplotlib.pyplot as plt
+import numpy as np
+import time
+from matplotlib import font_manager
+import seaborn as sns
+import os
 
 def find_max_indices_numpy(L_dict):    
     keys_arr = np.array(list(L_dict.keys()))
```
```diff
@@ -73,9 +82,10 @@ def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b):
         
     
     print(f"Select social sensors finish! Get social sensors, num: {len(A)}")
-    return A
+    return A, all_cas
 
-def handle_cas(filename, obs_time=-1):
+def handle_cas(args):
+    filename, obs_time = args.input_file, args.obs_time
     print("Handle cascade begin!")
     user_event_dict = defaultdict(dict)
     event_user_dict = defaultdict(list)
```
```diff
@@ -84,8 +94,8 @@ def handle_cas(filename, obs_time=-1):
         for line in file:
             
             cascades_total += 1
-            # if cascades_total > 100:
-            #     break
+            if cascades_total > args.max_cas_num and args.max_cas_num != -1:
+                break
            parts = line.split(',')
             cascade_id = parts[0]
             activation_times = {}
```
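The input format implied by this parsing (and by `get_t0` further down) appears to be one cascade per line: a cascade id followed by comma-separated paths, each path a `/`-joined chain of node ids with a `:timestamp` suffix. A made-up record for illustration, exercised against the same parsing logic:

```python
# Hypothetical cascade record matching the parsing in handle_cas/get_t0:
# "<cascade_id>,<n0>:<t0>,<n0>/<n1>:<t1>,..."
line = "cas42,7:0,7/13:5,7/13/21:9"

parts = line.split(',')
cascade_id = parts[0]                     # "cas42"
for p in parts[1:]:
    nodes = p.split(':')[0].split('/')    # retweet path, e.g. ['7', '13']
    time_now = int(p.split(':')[1])       # activation time of the last node
    node_id = int(nodes[-1])              # the newly activated node
    print(cascade_id, node_id, time_now)
```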
```diff
@@ -128,31 +138,379 @@ def handle_cas(filename, obs_time=-1):
         
     return L_dict, user_event_dict, event_user_dict
 
-def generate_anchors(input_file, output_file, anchor_budget, obs_time = -1):
+def generate_anchors(args):
+    result_path, anchor_budget = args.result_path, args.anchor_budget
     
-    L_dict, user_event_dict, event_user_dict = handle_cas(input_file, obs_time)
+    L_dict, user_event_dict, event_user_dict = handle_cas(args)
     
     num_nodes = len(L_dict.keys())
-    max_anchor_num = int(num_nodes*0.02)
     
-    if anchor_budget > max_anchor_num:
+    anchor_num = 0
+    
+    if anchor_budget > 0.02:
+        max_anchor_num = int(num_nodes*0.02)
         print(f"Max anchor num is {max_anchor_num}, anchor_budget is set to {max_anchor_num}")
-        anchor_budget = max_anchor_num
-    
-    A = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_budget)
+        anchor_num = max_anchor_num
+    else:
+        anchor_num = int(num_nodes*anchor_budget)
     
+    A, all_cas = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_num)
+    os.makedirs(result_path, exist_ok=True)
+    output_file = os.path.join(result_path, f'anchors_{args.anchor_budget}.txt')
     with open(output_file, 'w') as file:
         for item in A:
             file.write(f"{item}\n")
+    return A, all_cas, user_event_dict
```
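The patch changes `anchor_budget` from an absolute count to a fraction of the node count, silently capped at 2% of the graph. A minimal sketch of that rule (`resolve_anchor_num` is a hypothetical helper name, not part of the patch):

```python
# Sketch of the budget rule introduced above: anchor_budget is a fraction
# of num_nodes, capped at 2% of the graph.
def resolve_anchor_num(num_nodes: int, anchor_budget: float) -> int:
    if anchor_budget > 0.02:
        return int(num_nodes * 0.02)      # cap at 2% of nodes
    return int(num_nodes * anchor_budget)

assert resolve_anchor_num(10_000, 0.005) == 50
assert resolve_anchor_num(10_000, 0.5) == 200   # capped
```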
The same hunk continues with the new logistic-fitting helpers:

```diff
+# Logistic function definition
+def logistic(t, K, r, t0):
+    return K / (1 + np.exp(-r * (t - t0)))
+
+def find_inflection_logistic(t_list, c_list):
+    t_array = np.array(t_list)
+    c_array = np.array(c_list)
+
+    # Normalize time to [0, 1]
+    t_min = t_array.min()
+    t_max = t_array.max()
+    
+    if len(set(c_array)) < 3 or max(c_array) <= 1 or t_max - t_min == 0:
+        return None
+    t_array_normalized = (t_array - t_min) / (t_max - t_min)
+
+    try:
+        p0 = [max(c_array) if max(c_array) > 0 else 1.0, 1.0, 0.5] # make sure the initial K is non-zero even if c_array is all zeros
+        
+        # Check that c_array is usable
+        if max(c_array) <= 0: # if c_array is all zeros or negative, an S-curve cannot be fitted
+             print("Warning: the cumulative counts (c_list) show no positive growth; cannot fit an S-curve.")
+             return None
+
+        # Adjust bounds: if K's lower bound is 0 while max(c_array) is also 0, problems may arise
+        # bounds = ([0, 0, 0], [np.inf, 10, 1]) # K r t0
+        # K > 0 (so set its lower bound to a tiny positive number, or derive it from the data)
+        # r > 0 (growth rate)
+        # t0 within [0, 1]
+        bounds = ([1e-6, 1e-6, 0], [np.inf, 10, 1]) # K and r must be positive
+
+        popt, _ = curve_fit(logistic, t_array_normalized, c_array, p0=p0, bounds=bounds, maxfev=10000)
+        K, r, t0 = popt
+        return t0
+    except (RuntimeError, ValueError):
+        print(f"Warning: curve fitting failed or returned invalid parameters")
+        return None
```
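`find_inflection_logistic` returns the fitted `t0`, the inflection point of `K / (1 + exp(-r * (t - t0)))`, on a time axis normalized to [0, 1]. A sanity check of the recovery on synthetic data (all parameter values below are made up for illustration):

```python
import numpy as np
from scipy.optimize import curve_fit

def logistic(t, K, r, t0):
    return K / (1 + np.exp(-r * (t - t0)))

# Synthetic cumulative-adoption curve with a known inflection point t0 = 0.4
t = np.linspace(0, 1, 50)
c = logistic(t, K=120, r=8, t0=0.4)

popt, _ = curve_fit(logistic, t, c, p0=[c.max(), 1.0, 0.5],
                    bounds=([1e-6, 1e-6, 0], [np.inf, 10, 1]), maxfev=10000)
print(popt)  # should recover roughly (120, 8, 0.4)
```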
Next the hunk adds `get_t0`, which fits every cascade and caches the results in `prepare.json`:

```diff
+# Get the outbreak time of every event
+def get_t0(args):
+    input_file, result_path = args.input_file, args.result_path
+    
+    print("To get All Cas' T0")
+    os.makedirs(result_path, exist_ok=True)
+    json_file = os.path.join(result_path, 'prepare.json')
+    prepare = {}
+    if os.path.exists(json_file):
+        print(f"File {json_file} exists.")
+        with open(json_file, 'r', encoding='utf-8') as f:
+            prepare = json.load(f)
+        return prepare
+    t0_dict = {}
+    num_time_points = 50  # number of points sampled on the average-trend plot; adjustable
+    common_normalized_timeline = np.linspace(0, 1, num_time_points)
+    nodes_in_bins = {i: 0 for i in range(len(common_normalized_timeline) - 1)}
+
+    interpolated_counts_all_cascades = []
+    cascades_total = 0
+    t0_mean = 0
+    with open(input_file) as file:
+        for line in file:
+            cascades_total += 1
+            if cascades_total > args.max_cas_num and args.max_cas_num != -1:
+                break
+            parts = line.split(',')
+            cascade_id = parts[0]
+
+            paths = parts[1:]
+            t_max = 0
+            t_min = float('inf')
+            cas_set = set()
+            times = []
+            counts = []
+            for p in paths:
+                nodes = p.split(':')[0].split('/')
+                time_now = int(p.split(':')[1])
+                if time_now > t_max:
+                    t_max = time_now
+                if time_now < t_min:
+                    t_min = time_now
+                node = nodes[-1]
+                node_id = int(node)
+                cas_set.add(node_id)
+                times.append(time_now)
+                counts.append(len(cas_set))
+            
+            sorted_times = sorted(times)
+            if t_max - t_min <= 0:
+                continue
+            nodes_in_bin = {i: 0 for i in range(len(common_normalized_timeline) - 1)}
+            for i in range(len(times)):
+                participation_time = (sorted_times[i] - t_min)/(t_max - t_min)
+                sorted_times[i] = participation_time
+                bin_index = np.digitize(participation_time, common_normalized_timeline, right=False) - 1
+                if bin_index == -1 and participation_time == common_normalized_timeline[0]:
+                    bin_index = 0
+                elif bin_index == len(common_normalized_timeline) - 1:
+                    if participation_time == common_normalized_timeline[-1]:
+                        bin_index = len(common_normalized_timeline) - 2
+                    else:
+                        continue
+                if 0 <= bin_index < len(nodes_in_bin):
+                    nodes_in_bins[bin_index] += 1
+                    nodes_in_bin[bin_index] += 1
+            
+            interpolated_counts = np.interp(common_normalized_timeline, 
+                                        sorted_times, 
+                                        counts)
+            estimated_increments_between_common_times = np.diff(interpolated_counts)
+            interpolated_counts_all_cascades.append(estimated_increments_between_common_times)
+            
+            t0 = find_inflection_logistic(sorted_times, counts)
+            if t0 is None:
+                continue
+            t0_dict[cascade_id] = t0
+            t0_mean += t0
+            
+    t0_mean /= len(t0_dict)
+    stacked_interpolated_counts = np.vstack(interpolated_counts_all_cascades).tolist()
+    counts_0 = np.array([nodes_in_bins[i] for i in sorted(nodes_in_bins.keys())])
+    all_counts = (counts_0 / counts_0.max()).tolist()
+    
+    prepare['t0'] = t0_dict
+    prepare['t0_mean'] = t0_mean
+    prepare['interpolated_counts_sum'] = stacked_interpolated_counts
+    prepare['interpolated_counts_step_avg'] = all_counts
+    
+    try:
+        with open(json_file, 'w', encoding='utf-8') as f:
+            json.dump(prepare, f, indent=4, ensure_ascii=False)
+        print(f"t0_dict successfully saved to JSON file: {json_file}")
+    except TypeError as e:
+        print(f"Error: the dict may contain data types that cannot be serialized to JSON. Error: {e}")
+    except IOError:
+        print(f"Error: cannot open or write the file '{json_file}'.")
+    except Exception as e:
+        print(f"Unknown error while saving the JSON file: {e}")
+        
+    return prepare
```
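Both `get_t0` and `analyse_result` below bin normalized times with `np.digitize` and then patch up two edge cases (a value below the first edge and a value exactly at 1.0). A small worked example of that indexing, mirroring the patch's logic:

```python
import numpy as np

# Worked example of the bin-index arithmetic used in get_t0/analyse_result
edges = np.linspace(0, 1, 6)            # 5 bins: [0,0.2), [0.2,0.4), ..., [0.8,1.0]
for t in [0.0, 0.3, 1.0]:
    i = np.digitize(t, edges, right=False) - 1
    if i == -1 and t == edges[0]:        # defensive guard, as in the patch
        i = 0
    elif i == len(edges) - 1:
        if t == edges[-1]:
            i = len(edges) - 2           # t == 1.0 is folded into the last bin
        else:
            continue                     # out-of-range value is skipped
    print(t, '-> bin', i)                # 0.0 -> bin 0, 0.3 -> bin 1, 1.0 -> bin 4
```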
The hunk's last new function, `analyse_result`, aggregates the anchors' activation times and renders the two figures:

```diff
+def analyse_result(node_list, user_event_dict, prepare, args):
+    result_path = args.result_path
+    t0_dict = prepare['t0']
+    t0_mean = prepare['t0_mean']
+    anchors_activation_times_before_t0 = defaultdict(list)
+    anchors_activation_times = defaultdict(list)
+    cas_set = set()
+    for k,v in user_event_dict.items():
+        if k in node_list:
+            for cas_id, ac_time in v.items():
+                cas_set.add(cas_id)
+                if cas_id not in t0_dict:
+                    continue
+                anchors_activation_times_before_t0[k].append(ac_time - t0_dict[cas_id])
+                anchors_activation_times[k].append(ac_time)
+    
+    anchors_avg_activation_time_before_t0 = dict()
+    anchors_avg_activation_time = dict()
+    
+    for node_id, times in anchors_activation_times_before_t0.items():
+        if times:
+            anchors_avg_activation_time_before_t0[node_id] = (sum(times) / len(times))
+        else:
+            anchors_avg_activation_time_before_t0[node_id] = None
+    
+    for node_id, times in anchors_activation_times.items():
+        if times:
+            anchors_avg_activation_time[node_id] = (sum(times) / len(times))
+        else:
+            anchors_avg_activation_time[node_id] = None
+            
+    anchors_total_avg_time_before_t0 = sum(t for t in anchors_avg_activation_time_before_t0.values() if t is not None) / len(anchors_avg_activation_time_before_t0)
+    anchors_total_avg_time = sum(t for t in anchors_avg_activation_time.values() if t is not None) / len(anchors_avg_activation_time)
+
+    os.makedirs(result_path, exist_ok=True)
+    
+    anylyze_result = {}
+    anylyze_result['anchors_act_cas_num'] = len(cas_set)
+    anylyze_result['act_cas'] = list(cas_set)
+    anylyze_result['anchors_total_avg_time_before_t0'] = anchors_total_avg_time_before_t0
+    anylyze_result['anchors_avg_activation_time_before_t0'] = anchors_avg_activation_time_before_t0
+    anylyze_result['anchors_total_avg_time'] = anchors_total_avg_time
+    anylyze_result['anchors_avg_activation'] = anchors_avg_activation_time
+    
+    json_file = os.path.join(result_path, f'anylyze_result_{args.anchor_budget}.json')
+    try:
+        with open(json_file, 'w', encoding='utf-8') as f:
+            json.dump(anylyze_result, f, indent=4, ensure_ascii=False)
+        print(f"anylyze_result successfully saved to JSON file: {json_file}")
+    except TypeError as e:
+        print(f"Error: the dict may contain data types that cannot be serialized to JSON. Error: {e}")
+    except IOError:
+        print(f"Error: cannot open or write the file '{json_file}'.")
+    except Exception as e:
+        print(f"Unknown error while saving the JSON file: {e}")
+    
+    print("Start plotting the activation-time figure")
+    # --- Plot activation times ---
+    important_times = [t for t in anchors_avg_activation_time_before_t0.values() if t is not None]
+    # Then draw the chart
+    # Assumes important_times (and random_times) are available
+    plt.figure(figsize=(10, 6))
+    sns.kdeplot(important_times, label='anchors', fill=True, color='red', linewidth=2)
+    
+    plt.title('Anchors activate time')
+    plt.xlabel('Time')
+    plt.ylabel('P')
+    plt.legend()
+    plt.grid(True)
+    plt.tight_layout()
+
+    fig_file_1 = os.path.join(result_path, f'anchor_act_time_{args.anchor_budget}.png')
+    plt.savefig(fig_file_1,          # file name
+                dpi=300,             # optional: resolution (dots per inch)
+                bbox_inches='tight', # optional: try to trim blank margins
+               )
+    plt.close()
+    
+    print("Start plotting the trend figure")
+    
+    # --- Plot the trend ---
+    min_time = 0.0
+    max_time = 1 # assumed maximum observation time; adjust to your data
+    bin_width = 0.01
+    num_bins = int(np.ceil((max_time - min_time) / bin_width)) # round up to get the bin count
+
+    # Create the bin edges
+    # e.g. [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
+    time_bin_edges = np.linspace(min_time, max_time, num_bins + 1)
+    
+    nodes_in_bins = {i: 0 for i in range(len(time_bin_edges) - 1)}
+    for node, participation_time in anchors_avg_activation_time.items():
+        bin_index = np.digitize(participation_time, time_bin_edges, right=False) - 1
+        if bin_index == -1 and participation_time == time_bin_edges[0]:
+            bin_index = 0
+        elif bin_index == len(time_bin_edges) - 1:
+            if participation_time == time_bin_edges[-1]:
+                bin_index = len(time_bin_edges) - 2
+            else:
+                continue
+        if 0 <= bin_index < len(nodes_in_bins):
+            nodes_in_bins[bin_index] += 1
+    
+    plt.figure(figsize=(10, 6))  
+    
+    bin_labels = []
+    for i in range(len(time_bin_edges) - 1):
+        start_time = time_bin_edges[i]
+        end_time = time_bin_edges[i+1]
+        # Either a string for the interval, or the interval's midpoint
+        bin_labels.append(f"[{start_time:.2f}-{end_time:.2f})")
+        # bin_labels.append(f"{start_time:.1f}-") # terser label showing only the start
+
+    # Counts for each bin
+    counts_0 = np.array([nodes_in_bins[i] for i in sorted(nodes_in_bins.keys())])
+    counts = (counts_0 / counts_0.max()).tolist()
+    # X-axis positions for the bars
+    x_positions = time_bin_edges[:-1] + bin_width/2
+    bars = plt.bar(x_positions, counts, 
+               width=bin_width,   # bar width
+               color='skyblue',   # bar color
+               edgecolor='black') # bar edge color
+    # bar_index = 0
+    # for bar in bars:
+    #     yval = bar.get_height()
+    #     if yval > 0: # only label non-zero bars
+    #         plt.text(bar.get_x() + bar.get_width()/2.0, yval + 0.05, # nudge the text position
+    #                 counts_0[bar_index],      
+    #                 ha='center',    # horizontal alignment
+    #                 va='bottom')    # vertical alignment
+    #     bar_index +=1
+    
+    # stacked_interpolated_counts = np.array(prepare['interpolated_counts'])
+
+    # stacked_log_counts = stacked_interpolated_counts
+    # # Average across the cascade dimension (axis=0)
+    # average_counts_trend = np.mean(stacked_log_counts, axis=0) # (num_time_points,)
+    # max_counts = average_counts_trend.max()
+
+    # std_counts_trend = np.std(stacked_log_counts, axis=0)
+    # sem_counts_trend = std_counts_trend / np.sqrt( average_counts_trend.shape[0])
+    # # 95% confidence interval (using a z-score of 1.96)
+    # confidence_interval_upper = average_counts_trend + 1.96 * sem_counts_trend
+    # confidence_interval_lower = average_counts_trend - 1.96 * sem_counts_trend
+    
+    # average_counts_trend_normalized = average_counts_trend / max_counts
+    # confidence_interval_upper_normalized = confidence_interval_upper / max_counts
+    # confidence_interval_lower_normalized = confidence_interval_lower / max_counts
+    # confidence_interval_lower_corrected = np.maximum(0, confidence_interval_lower_normalized)
+    
+    # common_normalized_timeline = np.linspace(0, 1, average_counts_trend.shape[0])
+    # plt.plot(common_normalized_timeline, average_counts_trend_normalized, label='Average Participation Trend', color='blue', linewidth=2)
+
+    # # (optional) draw the confidence interval / standard-deviation band
+    # plt.fill_between(common_normalized_timeline, 
+    #                 confidence_interval_lower_corrected, 
+    #                 confidence_interval_upper_normalized, 
+    #                 color='blue', alpha=0.2, label='95% Confidence Interval')
+    
+    # counts_step = prepare['interpolated_counts_step_avg']
+    # common_normalized_timeline = np.linspace(0, 1, len(counts_step))
+    # plt.plot(common_normalized_timeline, counts_step, label='Average Participation Trend', color='blue', linewidth=2)
+
+    # --- Draw a vertical red line with plt.axvline() ---
+    plt.axvline(x=t0_mean,          # x coordinate of the line
+                color='red',        # line color
+                linestyle='--',     # line style (e.g. 'solid', '--', '-.', ':')
+                linewidth=2,        # line width
+                label=f'Average outbreak time = {t0_mean}') 
+    fig_file_2 = os.path.join(result_path, f'trend_{args.anchor_budget}.png')
+    plt.savefig(fig_file_2,          # file name
+                dpi=300,             # optional: resolution (dots per inch)
+                bbox_inches='tight', # optional: try to trim blank margins
+               )
+    plt.close()
+    return
```
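`analyse_result` writes `anylyze_result_<budget>.json` under `result_path`. A minimal sketch of loading it for downstream inspection (the paths are hypothetical and assume the defaults below; the keys match those written above):

```python
import json, os

# Hypothetical paths; anchor_budget matches the parse_args default below
result_path, anchor_budget = './result/', 0.005
with open(os.path.join(result_path, f'anylyze_result_{anchor_budget}.json'), encoding='utf-8') as f:
    res = json.load(f)

print(res['anchors_act_cas_num'])               # number of cascades the anchors sensed
print(res['anchors_total_avg_time_before_t0'])  # mean activation time relative to outbreak t0
```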
Finally, `parse_args` and the entry point are updated:

```diff
 def parse_args():
     parser = argparse.ArgumentParser(description='Parameters')
     parser.add_argument('--input_file', default='./dataset_for_anchor.txt', type=str, help='Cascade file')
-    parser.add_argument('--output_file', default='./anchors.txt', type=str, help='Anchors save file')
-    parser.add_argument('--anchor_budget', default=100, type=int, help='Anchors num')
+    parser.add_argument('--result_path', default='./result/', type=str, help='result save path')
+    parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchors num')
+    parser.add_argument('--max_cas_num', default=-1, type=int, help='Cascade num')
     parser.add_argument('--obs_time', default=-1, type=int, help='Anchors observe time, default setting is -1, meaning all can be observed')
+    parser.add_argument('--key_node_json', default='./key_node.json', type=str, help='key_node_json save file')
     return parser.parse_args()
 
 if __name__ == '__main__':
     args = parse_args()
-    generate_anchors(args.input_file, args.output_file, args.anchor_budget, args.obs_time)
+    
+    #=========== Run the m-point mining algorithm; returns the m-points and the number of sensed cascades ===========
+    A, all_cas, user_event_dict = generate_anchors(args)
+    
+    # #=========== Read important nodes mined by other algorithms ===========
+    # with open(args.key_node_json, 'r', encoding='utf-8') as f:
+    #     key_node_json = json.load(f) # load the JSON from the file object f
+    #     print("JSON file loaded successfully!")
+    #     key_node_json['res']
+    
+    #=========== Analyze the results ===========
+    print("Start analyzing the results")
+    prepare = get_t0(args)
+    analyse_result(A, user_event_dict, prepare, args)
```
							
								
								
									
anchors.txt (deleted, 100 lines)

```diff
@@ -1,100 +0,0 @@
-154624
-166401
-48133
-121862
-3083
-57356
-146451
-397331
-67102
-82463
-23074
-49186
-48167
-77357
-75822
-143413
-290884
-360517
-48199
-64077
-7250
-51293
-361568
-509538
-26212
-25221
-44166
-247942
-8841
-434324
-126616
-64158
-176812
-233648
-50356
-197814
-445625
-63161
-106701
-472794
-588514
-535781
-412904
-517867
-55029
-108281
-216314
-59135
-292613
-11527
-247568
-84756
-309528
-217373
-30500
-3366
-291121
-333115
-78139
-9026
-118597
-145744
-409428
-139100
-43877
-57711
-515444
-121205
-1917
-60286
-315281
-70552
-58265
-47514
-56728
-49565
-138654
-66976
-80291
-49575
-70058
-522157
-83390
-83394
-123331
-3524
-82897
-27604
-48086
-149465
-3034
-83419
-427485
-48094
-41956
-491493
-456174
-144371
-84982
-50168
```