Compare commits

...

3 Commits

Author SHA1 Message Date
ecb3c705c2 new file: .gitignore 2025-05-28 15:43:29 +08:00
bfcba4660a modified: README.md 2025-05-28 15:41:34 +08:00
128ed7994e modified: README.md; modified: anchor.py; deleted: anchors.txt 2025-05-28 15:40:11 +08:00
4 changed files with 377 additions and 121 deletions

.gitignore (new file)

@@ -0,0 +1 @@
+result/

README.md

@@ -1,11 +1,8 @@
-Required Python packages: numpy
+Required Python packages: numpy, seaborn, matplotlib, scipy
Running python anchor.py automatically saves the computed m points locally; four parameters are required:
input_file  path to the cascade data file;
-output_file  path where the computed m points are saved;
+result_path  path where the computed results are saved;
anchor_budget  m-point budget;
obs_time  observable time; defaults to -1 and may be omitted;
+max_cas_num  number of cascades to process; defaults to -1 and may be omitted
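For reference, a minimal sketch of how the updated script might be invoked, using only the argparse defaults introduced in this commit (the paths and the 0.005 budget below are simply those defaults, not extra requirements):

    python anchor.py --input_file ./dataset_for_anchor.txt --result_path ./result/ --anchor_budget 0.005 --obs_time -1 --max_cas_num -1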

anchor.py

@@ -1,7 +1,16 @@
from collections import defaultdict
-import time
import numpy as np
import copy
import argparse
+import json
+from scipy.optimize import curve_fit
+import matplotlib.pyplot as plt
+import numpy as np
+import time
+from matplotlib import font_manager
+import seaborn as sns
+import os

def find_max_indices_numpy(L_dict):
    keys_arr = np.array(list(L_dict.keys()))
@@ -73,9 +82,10 @@ def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b):
    print(f"Select social sensors finish! Get social sensors, num: {len(A)}")
-    return A
+    return A, all_cas

-def handle_cas(filename, obs_time=-1):
+def handle_cas(args):
+    filename, obs_time = args.input_file, args.obs_time
    print("Handle cascade begin!")
    user_event_dict = defaultdict(dict)
    event_user_dict = defaultdict(list)
@@ -84,8 +94,8 @@ def handle_cas(filename, obs_time=-1):
        for line in file:
            cascades_total += 1
-            # if cascades_total > 100:
-            #     break
+            if cascades_total > args.max_cas_num and args.max_cas_num != -1:
+                break
            parts = line.split(',')
            cascade_id = parts[0]
            activation_times = {}
@@ -128,31 +138,379 @@ def handle_cas(filename, obs_time=-1):
    return L_dict, user_event_dict, event_user_dict

-def generate_anchors(input_file, output_file, anchor_budget, obs_time = -1):
+def generate_anchors(args):
+    result_path, anchor_budget = args.result_path, args.anchor_budget
-    L_dict, user_event_dict, event_user_dict = handle_cas(input_file, obs_time)
+    L_dict, user_event_dict, event_user_dict = handle_cas(args)
    num_nodes = len(L_dict.keys())
-    max_anchor_num = int(num_nodes*0.02)
-    if anchor_budget > max_anchor_num:
+    anchor_num = 0
+    if anchor_budget > 0.02:
+        max_anchor_num = int(num_nodes*0.02)
        print(f"Max anchor num is {max_anchor_num}, anchor_budget is set to {max_anchor_num}")
-        anchor_budget = max_anchor_num
+        anchor_num = max_anchor_num
+    else:
+        anchor_num = int(num_nodes*anchor_budget)
-    A = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_budget)
+    A, all_cas = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_num)
+    os.makedirs(result_path, exist_ok=True)
+    output_file = os.path.join(result_path, f'anchors_{args.anchor_budget}.txt')
    with open(output_file, 'w') as file:
        for item in A:
            file.write(f"{item}\n")
+    return A, all_cas, user_event_dict
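# Illustrative read-back of the saved m points (a sketch assuming the default
# result_path './result/' and anchor_budget 0.005; the file holds one node id per line):
#   with open('./result/anchors_0.005.txt') as f:
#       anchors = [line.strip() for line in f]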
# Definition of the logistic function
def logistic(t, K, r, t0):
    return K / (1 + np.exp(-r * (t - t0)))

def find_inflection_logistic(t_list, c_list):
    t_array = np.array(t_list)
    c_array = np.array(c_list)
    # Normalize time to [0, 1]
    t_min = t_array.min()
    t_max = t_array.max()
    if len(set(c_array)) < 3 or max(c_array) <= 1 or t_max - t_min == 0:
        return None
    t_array_normalized = (t_array - t_min) / (t_max - t_min)
    try:
        p0 = [max(c_array) if max(c_array) > 0 else 1.0, 1.0, 0.5]  # make sure the initial K is not 0 if c_array is all zeros
        # Check whether c_array is valid
        if max(c_array) <= 0:  # if c_array is all zeros or negative, the S-curve cannot be fitted either
            print("Warning: the cumulative values (c_list) show no positive growth; the S-curve cannot be fitted.")
            return None
        # Adjust bounds: if the lower bound of K is 0 and max(c_array) is also 0, this may cause problems
        # bounds = ([0, 0, 0], [np.inf, 10, 1])  # K, r, t0
        # K > 0 (so set the lower bound to a tiny positive number, or base it on the data)
        # r > 0 (growth rate)
        # t0 lies in [0, 1]
        bounds = ([1e-6, 1e-6, 0], [np.inf, 10, 1])  # K and r should be positive
        popt, _ = curve_fit(logistic, t_array_normalized, c_array, p0=p0, bounds=bounds, maxfev=10000)
        K, r, t0 = popt
        return t0
    except (RuntimeError, ValueError):
        print(f"Warning: curve fitting failed or the parameters are invalid")
        return None
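# A quick illustrative sanity check of the fit (assumes only the imports above):
# a synthetic S-curve generated with K=100, r=10, t0=0.4 should give a normalized
# inflection point close to 0.4.
#   t_demo = np.linspace(0, 1, 50).tolist()
#   c_demo = [100 / (1 + np.exp(-10 * (t - 0.4))) for t in t_demo]
#   print(find_inflection_logistic(t_demo, c_demo))   # expected: roughly 0.4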
# Get the outbreak time (t0) of every cascade
def get_t0(args):
    input_file, result_path = args.input_file, args.result_path
    print("To get All Cas' T0")
    os.makedirs(result_path, exist_ok=True)
    json_file = os.path.join(result_path, 'prepare.json')
    prepare = {}
    if os.path.exists(json_file):
        print(f"File {json_file} already exists.")
        with open(json_file, 'r', encoding='utf-8') as f:
            prepare = json.load(f)
        return prepare
    t0_dict = {}
    num_time_points = 50  # number of points to sample on the average-trend plot; adjustable
    common_normalized_timeline = np.linspace(0, 1, num_time_points)
    nodes_in_bins = {i: 0 for i in range(len(common_normalized_timeline) - 1)}
    interpolated_counts_all_cascades = []
    cascades_total = 0
    t0_mean = 0
    with open(input_file) as file:
        for line in file:
            cascades_total += 1
            if cascades_total > args.max_cas_num and args.max_cas_num != -1:
                break
            parts = line.split(',')
            cascade_id = parts[0]
            paths = parts[1:]
            t_max = 0
            t_min = float('inf')
            cas_set = set()
            times = []
            counts = []
            for p in paths:
                nodes = p.split(':')[0].split('/')
                time_now = int(p.split(':')[1])
                if time_now > t_max:
                    t_max = time_now
                if time_now < t_min:
                    t_min = time_now
                node = nodes[-1]
                node_id = int(node)
                cas_set.add(node_id)
                times.append(time_now)
                counts.append(len(cas_set))
            sorted_times = sorted(times)
            if t_max - t_min <= 0:
                continue
            nodes_in_bin = {i: 0 for i in range(len(common_normalized_timeline) - 1)}
            for i in range(len(times)):
                participation_time = (sorted_times[i] - t_min)/(t_max - t_min)
                sorted_times[i] = participation_time
                bin_index = np.digitize(participation_time, common_normalized_timeline, right=False) - 1
                if bin_index == -1 and participation_time == common_normalized_timeline[0]:
                    bin_index = 0
                elif bin_index == len(common_normalized_timeline) - 1:
                    if participation_time == common_normalized_timeline[-1]:
                        bin_index = len(common_normalized_timeline) - 2
                    else:
                        continue
                if 0 <= bin_index < len(nodes_in_bin):
                    nodes_in_bins[bin_index] += 1
                    nodes_in_bin[bin_index] += 1
            interpolated_counts = np.interp(common_normalized_timeline,
                                            sorted_times,
                                            counts)
            estimated_increments_between_common_times = np.diff(interpolated_counts)
            interpolated_counts_all_cascades.append(estimated_increments_between_common_times)
            t0 = find_inflection_logistic(sorted_times, counts)
            if t0 is None:
                continue
            t0_dict[cascade_id] = t0
            t0_mean += t0
    t0_mean /= len(t0_dict)
    stacked_interpolated_counts = np.vstack(interpolated_counts_all_cascades).tolist()
    counts_0 = np.array([nodes_in_bins[i] for i in sorted(nodes_in_bins.keys())])
    all_counts = (counts_0 / counts_0.max()).tolist()
    prepare['t0'] = t0_dict
    prepare['t0_mean'] = t0_mean
    prepare['interpolated_counts_sum'] = stacked_interpolated_counts
    prepare['interpolated_counts_step_avg'] = all_counts
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(prepare, f, indent=4, ensure_ascii=False)
        print(f"t0_dict was saved to the JSON file: {json_file}")
    except TypeError as e:
        print(f"Error: the dictionary may contain data types that cannot be serialized to JSON. Error: {e}")
    except IOError:
        print(f"Error: cannot open or write to file '{json_file}'")
    except Exception as e:
        print(f"An unknown error occurred while saving the JSON file: {e}")
    return prepare
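# Illustrative usage sketch (assumes the default result_path './result/'): once
# prepare.json exists, the cached outbreak statistics can be reloaded directly.
#   with open('./result/prepare.json', 'r', encoding='utf-8') as f:
#       prepare = json.load(f)
#   print(prepare['t0_mean'])   # average normalized outbreak time across cascades
#   print(len(prepare['t0']))   # number of cascades with a fitted t0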
def analyse_result(node_list, user_event_dict, prepare, args):
    result_path = args.result_path
    t0_dict = prepare['t0']
    t0_mean = prepare['t0_mean']
    anchors_activation_times_before_t0 = defaultdict(list)
    anchors_activation_times = defaultdict(list)
    cas_set = set()
    for k, v in user_event_dict.items():
        if k in node_list:
            for cas_id, ac_time in v.items():
                cas_set.add(cas_id)
                if cas_id not in t0_dict:
                    continue
                anchors_activation_times_before_t0[k].append(ac_time - t0_dict[cas_id])
                anchors_activation_times[k].append(ac_time)
    anchors_avg_activation_time_before_t0 = dict()
    anchors_avg_activation_time = dict()
    for node_id, times in anchors_activation_times_before_t0.items():
        if times:
            anchors_avg_activation_time_before_t0[node_id] = (sum(times) / len(times))
        else:
            anchors_avg_activation_time_before_t0[node_id] = None
    for node_id, times in anchors_activation_times.items():
        if times:
            anchors_avg_activation_time[node_id] = (sum(times) / len(times))
        else:
            anchors_avg_activation_time[node_id] = None
    anchors_total_avg_time_before_t0 = sum(t for t in anchors_avg_activation_time_before_t0.values() if t is not None) / len(anchors_avg_activation_time_before_t0)
    anchors_total_avg_time = sum(t for t in anchors_avg_activation_time.values() if t is not None) / len(anchors_avg_activation_time)
    os.makedirs(result_path, exist_ok=True)
    anylyze_result = {}
    anylyze_result['anchors_act_cas_num'] = len(cas_set)
    anylyze_result['act_cas'] = list(cas_set)
    anylyze_result['anchors_total_avg_time_before_t0'] = anchors_total_avg_time_before_t0
    anylyze_result['anchors_avg_activation_time_before_t0'] = anchors_avg_activation_time_before_t0
    anylyze_result['anchors_total_avg_time'] = anchors_total_avg_time
    anylyze_result['anchors_avg_activation'] = anchors_avg_activation_time
    json_file = os.path.join(result_path, f'anylyze_result_{args.anchor_budget}.json')
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(anylyze_result, f, indent=4, ensure_ascii=False)
        print(f"anylyze_result was saved to the JSON file: {json_file}")
    except TypeError as e:
        print(f"Error: the dictionary may contain data types that cannot be serialized to JSON. Error: {e}")
    except IOError:
        print(f"Error: cannot open or write to file '{json_file}'")
    except Exception as e:
        print(f"An unknown error occurred while saving the JSON file: {e}")
    print("Start plotting the activation-time figure")
    # --- Plot the activation-time figure ---
    important_times = [t for t in anchors_avg_activation_time_before_t0.values() if t is not None]
    # Then continue drawing the figure
    # Assume important_times and random_times already exist
    plt.figure(figsize=(10, 6))
    sns.kdeplot(important_times, label='anchors', fill=True, color='red', linewidth=2)
    plt.title('Anchors activate time')
    plt.xlabel('Time')
    plt.ylabel('P')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    fig_file_1 = os.path.join(result_path, f'anchor_act_time_{args.anchor_budget}.png')
    plt.savefig(fig_file_1,           # file name
                dpi=300,              # optional: resolution (dots per inch)
                bbox_inches='tight',  # optional: try to trim blank margins
                )
    plt.close()
    print("Start plotting the trend figure")
    # --- Plot the trend figure ---
    min_time = 0.0
    max_time = 1  # assume the maximum observation time is 0.5; adjust to your data
    bin_width = 0.01
    num_bins = int(np.ceil((max_time - min_time) / bin_width))  # round up to compute the number of bins
    # Create the time-bin edges
    # e.g. [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
    time_bin_edges = np.linspace(min_time, max_time, num_bins + 1)
    nodes_in_bins = {i: 0 for i in range(len(time_bin_edges) - 1)}
    for node, participation_time in anchors_avg_activation_time.items():
        bin_index = np.digitize(participation_time, time_bin_edges, right=False) - 1
        if bin_index == -1 and participation_time == time_bin_edges[0]:
            bin_index = 0
        elif bin_index == len(time_bin_edges) - 1:
            if participation_time == time_bin_edges[-1]:
                bin_index = len(time_bin_edges) - 2
            else:
                continue
        if 0 <= bin_index < len(nodes_in_bins):
            nodes_in_bins[bin_index] += 1
    plt.figure(figsize=(10, 6))
    bin_labels = []
    for i in range(len(time_bin_edges) - 1):
        start_time = time_bin_edges[i]
        end_time = time_bin_edges[i+1]
        # you can use a string representation of the interval, or its center point
        bin_labels.append(f"[{start_time:.2f}-{end_time:.2f})")
        # bin_labels.append(f"{start_time:.1f}-")  # a more compact label showing only the start
    # Get the count for each bin
    counts_0 = np.array([nodes_in_bins[i] for i in sorted(nodes_in_bins.keys())])
    counts = (counts_0 / counts_0.max()).tolist()
    # X-axis positions (for the bar chart)
    x_positions = time_bin_edges[:-1] + bin_width/2
    bars = plt.bar(x_positions, counts,
                   width=bin_width,    # bar width
                   color='skyblue',    # bar color
                   edgecolor='black')  # bar edge color
    # bar_index = 0
    # for bar in bars:
    #     yval = bar.get_height()
    #     if yval > 0:  # only annotate non-zero bars
    #         plt.text(bar.get_x() + bar.get_width()/2.0, yval + 0.05,  # fine-tune the text position
    #                  counts_0[bar_index],
    #                  ha='center',   # horizontal alignment
    #                  va='bottom')   # vertical alignment
    #     bar_index += 1
    # stacked_interpolated_counts = np.array(prepare['interpolated_counts'])
    # stacked_log_counts = stacked_interpolated_counts
    # # Average along the cascade dimension (axis=0)
    # average_counts_trend = np.mean(stacked_log_counts, axis=0)  # (num_time_points,)
    # max_counts = average_counts_trend.max()
    # std_counts_trend = np.std(stacked_log_counts, axis=0)
    # sem_counts_trend = std_counts_trend / np.sqrt(average_counts_trend.shape[0])
    # # Compute the 95% confidence interval (using 1.96 as the z-score)
    # confidence_interval_upper = average_counts_trend + 1.96 * sem_counts_trend
    # confidence_interval_lower = average_counts_trend - 1.96 * sem_counts_trend
    # average_counts_trend_normalized = average_counts_trend / max_counts
    # confidence_interval_upper_normalized = confidence_interval_upper / max_counts
    # confidence_interval_lower_normalized = confidence_interval_lower / max_counts
    # confidence_interval_lower_corrected = np.maximum(0, confidence_interval_lower_normalized)
    # common_normalized_timeline = np.linspace(0, 1, average_counts_trend.shape[0])
    # plt.plot(common_normalized_timeline, average_counts_trend_normalized, label='Average Participation Trend', color='blue', linewidth=2)
    # # (Optional) plot the confidence interval or standard-deviation band
    # plt.fill_between(common_normalized_timeline,
    #                  confidence_interval_lower_corrected,
    #                  confidence_interval_upper_normalized,
    #                  color='blue', alpha=0.2, label='95% Confidence Interval')
    # counts_step = prepare['interpolated_counts_step_avg']
    # common_normalized_timeline = np.linspace(0, 1, len(counts_step))
    # plt.plot(common_normalized_timeline, counts_step, label='Average Participation Trend', color='blue', linewidth=2)
    # --- Draw a vertical red line with plt.axvline() ---
    plt.axvline(x=t0_mean,        # x coordinate of the line
                color='red',      # line color
                linestyle='--',   # line style (e.g. 'solid', '--', '-.', ':')
                linewidth=2,      # line width
                label=f'Average outbreak time = {t0_mean}')
    fig_file_2 = os.path.join(result_path, f'trend_{args.anchor_budget}.png')
    plt.savefig(fig_file_2,           # file name
                dpi=300,              # optional: resolution (dots per inch)
                bbox_inches='tight',  # optional: try to trim blank margins
                )
    plt.close()
    return
def parse_args():
    parser = argparse.ArgumentParser(description='Parameters')
    parser.add_argument('--input_file', default='./dataset_for_anchor.txt', type=str, help='Cascade file')
-    parser.add_argument('--output_file', default='./anchors.txt', type=str, help='Anchors save file')
+    parser.add_argument('--result_path', default='./result/', type=str, help='result save path')
-    parser.add_argument('--anchor_budget', default=100, type=int, help='Anchors num')
+    parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchors num')
+    parser.add_argument('--max_cas_num', default=-1, type=int, help='Cascade num')
    parser.add_argument('--obs_time', default=-1, type=int, help='Anchors observe time, default setting is -1, meaning can observe all')
+    parser.add_argument('--key_node_json', default='./key_node.json', type=str, help='key_node_json save file')
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
-    generate_anchors(args.input_file, args.output_file, args.anchor_budget, args.obs_time)
+    # =========== Run the m-point mining algorithm; it returns the m points and the number of cascades they sense ===========
+    A, all_cas, user_event_dict = generate_anchors(args)
+    # # =========== Load the important nodes mined by other algorithms ===========
+    # with open(args.key_node_json, 'r', encoding='utf-8') as f:
+    #     key_node_json = json.load(f)  # load the JSON from file object f
+    #     print("JSON file loaded successfully!")
+    #     key_node_json['res']
+    # =========== Analyze the results ===========
+    print("Start analyzing the results")
+    prepare = get_t0(args)
+    analyse_result(A, user_event_dict, prepare, args)

anchors.txt (deleted file)

@@ -1,100 +0,0 @@
154624
166401
48133
121862
3083
57356
146451
397331
67102
82463
23074
49186
48167
77357
75822
143413
290884
360517
48199
64077
7250
51293
361568
509538
26212
25221
44166
247942
8841
434324
126616
64158
176812
233648
50356
197814
445625
63161
106701
472794
588514
535781
412904
517867
55029
108281
216314
59135
292613
11527
247568
84756
309528
217373
30500
3366
291121
333115
78139
9026
118597
145744
409428
139100
43877
57711
515444
121205
1917
60286
315281
70552
58265
47514
56728
49565
138654
66976
80291
49575
70058
522157
83390
83394
123331
3524
82897
27604
48086
149465
3034
83419
427485
48094
41956
491493
456174
144371
84982
50168