# mdbs/model/anchor_finder.py
# (186 lines, 5.8 KiB, Python)

from collections import defaultdict
import numpy as np
import copy
import argparse
import json
import numpy as np
import os
def find_max_indices_numpy(L_dict):
    """Return all keys of ``L_dict`` that share the largest value, plus that value.

    Parameters:
    - L_dict: dict of key -> comparable value (must be non-empty; ``max``
      raises ``ValueError`` on an empty dict)

    Returns:
    - (keys, peak): ``keys`` is a NumPy array holding every key whose value
      equals ``peak``, in dict order; ``peak`` is the maximum value itself
      (``np.max`` over the value array would be an alternative way to get it).
    """
    key_array = np.array(list(L_dict))
    value_array = np.array(list(L_dict.values()))
    peak = max(L_dict.values())
    # A boolean mask keeps exactly the positions that tie for the maximum.
    tied_keys = key_array[value_array == peak]
    return tied_keys, peak
def R(event_dict):
    """Return the negated mean of the values stored in ``event_dict``.

    Parameters:
    - event_dict: dict of key -> numeric value (keys are ignored)

    Returns:
    - float: ``-(sum of values) / len(event_dict)``

    Raises:
    - ZeroDivisionError: if ``event_dict`` is empty
    """
    negated_total = sum(-value for value in event_dict.values())
    return negated_total / len(event_dict)
def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b,
                          anchor_screen=None):
    """Greedily pick social sensors (anchors) until the budget is exhausted.

    Each round looks only at the nodes that currently take part in the most
    not-yet-covered events, scores them with ``R`` and keeps the best one if
    its score reaches ``anchor_screen``. Events of a chosen node are then
    removed from every participant so later rounds favour uncovered events.

    NOTE(review): the nesting of the screening branch was reconstructed from
    a listing that had lost its indentation - confirm against the original.

    Parameters:
    - R: reward function; called as ``R(events)`` where ``events`` is the
      candidate's remaining ``event -> active_time`` dict, returns a number
    - L_dict: dict of node -> integer (event num)
    - user_event_dict: user -> event -> active_time
    - event_user_dict: event -> user_list
    - b: budget (numeric); selection stops once ``len(A) >= b``
    - anchor_screen: minimum reward a candidate needs to be accepted. When
      omitted, the module-level ``args.anchor_screen`` is used.

    Returns:
    - (A, all_cas): ``A`` is the list of selected nodes in selection order;
      ``all_cas`` is the number of events covered since the last reset.

    The input dicts are deep-copied and never modified.
    """
    if anchor_screen is None:
        # Backward-compatible fallback to the script's parsed CLI options.
        anchor_screen = args.anchor_screen

    print("Select social sensors begin!")
    user_event_dict_ = copy.deepcopy(user_event_dict)
    L_dict_ = copy.deepcopy(L_dict)
    A = []
    # Nodes not selected yet; a set keeps the loop guard O(1) per round.
    remaining = set(L_dict_)
    all_cas = 0

    while remaining and len(A) < b:
        max_val = max(L_dict_.values())
        if max_val == 0:
            print("no extra node in cas!!!!!")
            break
        # Candidates are the nodes tied for the highest remaining event
        # count. Plain Python keeps node ids in their original type.
        TAR_set = {node for node, count in L_dict_.items() if count == max_val}

        # Lazy evaluation: a candidate is scored the first time it surfaces
        # as the current maximum, and accepted as the winner the second time.
        delta = dict.fromkeys(TAR_set, float('inf'))
        cur = dict.fromkeys(TAR_set, False)
        while True:
            s_star = max(delta, key=delta.get)
            c_star = user_event_dict_[s_star]
            if cur[s_star]:
                break
            delta[s_star] = R(c_star)
            cur[s_star] = True

        if delta[s_star] < anchor_screen:
            # Best candidate fails the screen: drop it from future rounds
            # without touching the coverage bookkeeping.
            L_dict_[s_star] = 0
            continue

        A.append(s_star)
        remaining.discard(s_star)
        all_cas += len(c_star)

        # Iterate over a snapshot: c_star itself shrinks when s_star is
        # reached in the participant list below.
        for cas_id in list(c_star):
            for v in event_user_dict[cas_id]:
                if v in L_dict_:
                    L_dict_[v] -= 1
                    user_event_dict_[v].pop(cas_id)

        print(f"Add a social sensor, sensors num is {len(A)}")
        print(f"Anchor id: {s_star}")
        print(f"all_cas is {all_cas}")
        print(f"TAR_set size: {len(TAR_set)}")

        if all_cas >= len(event_user_dict):
            # Every event is covered: restore the full data and keep going
            # with the already selected nodes excluded.
            user_event_dict_ = copy.deepcopy(user_event_dict)
            L_dict_ = copy.deepcopy(L_dict)
            for s in A:
                L_dict_[s] = 0
            all_cas = 0
            print("All cas has been perception, add extra node")

    print(f"Select social sensors finish! Get social sensors, num: {len(A)}")
    return A, all_cas
def handle_event(event_dict, obs_time, cas_min_length=None,
                 min_start_time=None, max_start_time=None):
    """Turn raw cascades into the lookup tables used for anchor selection.

    Parameters:
    - event_dict: iterable of ``(cascade_id, event_list)`` pairs, or a dict
      mapping cascade_id -> event_list. Every event is a mapping with the
      keys ``'user_id'`` and ``'act_time'`` (``act_time`` is cast to int).
    - obs_time: events later than this are ignored when collecting users;
      ``-1`` disables the cut-off. The cascade's first and last timestamps
      are still taken over all of its events.
    - cas_min_length: cascades with fewer distinct observed users are skipped
    - min_start_time / max_start_time: cascades whose first timestamp lies
      outside this closed range are skipped

    Any threshold left as ``None`` is read from the module-level ``args``
    (``args.cas_min_length``, ``args.t_min``, ``args.t_max``).

    Returns:
    - L_dict: user -> number of cascades recorded for that user
    - user_event_dict: user -> cascade_id -> activation time rescaled to
      ``(t - t_min) / (t_max - t_min)`` of that cascade
    - event_user_dict: cascade_id -> list of observed users

    A cascade whose events all share one timestamp is listed in
    ``event_user_dict`` but contributes nothing to ``user_event_dict``,
    because the rescaling would divide by zero.
    """
    # Backward-compatible fallback to the script's parsed CLI options.
    if cas_min_length is None:
        cas_min_length = args.cas_min_length
    if min_start_time is None:
        min_start_time = args.t_min
    if max_start_time is None:
        max_start_time = args.t_max

    # Iterating a dict directly would yield bare keys and break the
    # two-value unpacking below, so go through its items instead.
    cascades = event_dict.items() if isinstance(event_dict, dict) else event_dict

    user_event_dict = defaultdict(dict)
    event_user_dict = defaultdict(list)
    for cascade_id, event_list in cascades:
        activation_times = {}
        t_max = 0
        t_min = float('inf')
        for u_t_dict in event_list:
            time_now = int(u_t_dict['act_time'])
            t_max = max(t_max, time_now)
            t_min = min(t_min, time_now)
            node_id = u_t_dict['user_id']
            if obs_time != -1 and time_now > obs_time:
                continue
            # Keep only the earliest activation of every user.
            earliest = activation_times.get(node_id, time_now)
            activation_times[node_id] = min(time_now, earliest)
        print(f"cascade_id:{cascade_id}, t_min:{t_min}")

        if len(activation_times) < cas_min_length:
            print(f"Too short Cas! Cas_id:{cascade_id}")
            continue
        if t_min > max_start_time or t_min < min_start_time:
            print(f"Cas is not in proper time! Cas_id:{cascade_id}")
            continue

        span = t_max - t_min
        for user, act_time in activation_times.items():
            event_user_dict[cascade_id].append(user)
            if span > 0:
                user_event_dict[user][cascade_id] = (act_time - t_min) / span

    L_dict = {user: len(events) for user, events in user_event_dict.items()}
    return L_dict, user_event_dict, event_user_dict
def generate_anchors(event_dict, anchor_budget, obs_time, result_path, anchor_num=-1):
    """Run the full pipeline: parse cascades, pick anchors, write results.

    Parameters:
    - event_dict: raw cascades, forwarded unchanged to ``handle_event``
    - anchor_budget: multiplier applied to the number of users to derive the
      anchor budget when ``anchor_num`` is -1
    - obs_time: observation cut-off, forwarded to ``handle_event``
    - result_path: output directory (created if missing)
    - anchor_num: explicit anchor budget; -1 means "derive from anchor_budget"

    Returns:
    - (A, all_cas, user_event_dict): the selected anchors, the covered-event
      counter reported by ``select_social_sensors``, and the user -> event
      table built by ``handle_event``

    Files written into ``result_path``:
    - ``user_event.json``: dump of ``user_event_dict`` (best effort; a
      failure is only printed)
    - ``anchors_<anchor_num>.txt``: one anchor per line. A derived budget is
      used unrounded, so the file name carries whatever the product
      formats to.
    """
    node_counts, user_event_dict, event_user_dict = handle_event(event_dict, obs_time)
    print(f"Handle cascade file finish! Users num is {len(node_counts)}")

    os.makedirs(result_path, exist_ok=True)
    json_file = os.path.join(result_path, 'user_event.json')
    try:
        with open(json_file, 'w', encoding='utf-8') as handle:
            json.dump(user_event_dict, handle, indent=4, ensure_ascii=False)
        print(f"user_event_dict已成功保存到 JSON 文件: {json_file}")
    except TypeError as err:
        print(f"错误: 字典中可能包含无法序列化为 JSON 的数据类型。错误: {err}")
    except OSError:
        print(f"错误: 无法打开或写入文件 '{json_file}'")
    except Exception as err:
        print(f"保存 JSON 文件时发生未知错误: {err}")

    if anchor_num == -1:
        anchor_num = len(node_counts) * anchor_budget

    A, all_cas = select_social_sensors(
        R, node_counts, user_event_dict, event_user_dict, anchor_num
    )

    # Kept from the original flow; a no-op while the directory still exists.
    os.makedirs(result_path, exist_ok=True)
    output_file = os.path.join(result_path, f'anchors_{anchor_num}.txt')
    with open(output_file, 'w') as out:
        out.writelines(f"{item}\n" for item in A)

    return A, all_cas, user_event_dict