Compare commits

...

2 Commits

Author SHA1 Message Date
032242359a Merge branch 'main' of ssh://172.16.20.1:2222/wanglei/mdbs 2025-07-09 15:57:34 +08:00
a367fe73b2 modified: anchor.py 2025-07-09 15:52:06 +08:00

View File

@ -23,10 +23,11 @@ def find_max_indices_numpy(L_dict):
return max_keys_np_way return max_keys_np_way
def R(event_dict): def R(event_dict):
sum = 0 mean = 0
for _,v in event_dict.items(): for _,v in event_dict.items():
sum += -v mean += -v
return sum mean /= len(event_dict)
return mean
def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b): def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b):
@ -65,11 +66,17 @@ def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b):
s_star = max(delta, key=delta.get) s_star = max(delta, key=delta.get)
c_star = user_event_dict_[s_star] c_star = user_event_dict_[s_star]
if cur[s_star] == True: if cur[s_star] == True:
A.append(s_star) if delta[s_star] >= args.anchor_screen:
A.append(s_star)
break break
else: else:
delta[s_star] = R(c_star) delta[s_star] = R(c_star)
cur[s_star] = True cur[s_star] = True
if delta[s_star] < args.anchor_screen:
L_dict_[s_star] = 0
continue
all_cas += len(c_star) all_cas += len(c_star)
for cas_id in list(c_star.keys()): for cas_id in list(c_star.keys()):
uc = event_user_dict[cas_id] uc = event_user_dict[cas_id]
@ -191,6 +198,7 @@ def generate_anchors(args):
anchor_num = max_anchor_num anchor_num = max_anchor_num
else: else:
anchor_num = int(num_nodes*anchor_budget) anchor_num = int(num_nodes*anchor_budget)
print(f"anchor_num is set to {anchor_num}")
# anchor_num = 10 # anchor_num = 10
A, all_cas = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_num) A, all_cas = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_num)
@ -529,10 +537,12 @@ def parse_args():
parser = argparse.ArgumentParser(description='Parameters') parser = argparse.ArgumentParser(description='Parameters')
parser.add_argument('--input_file', default='./dataset_for_anchor.txt', type=str, help='Cascade file') parser.add_argument('--input_file', default='./dataset_for_anchor.txt', type=str, help='Cascade file')
parser.add_argument('--result_path', default='./result/', type=str, help='result save path') parser.add_argument('--result_path', default='./result/', type=str, help='result save path')
parser.add_argument('--anchor_budget', default=0.0002, type=float, help='Anchors num') parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchors num')
parser.add_argument('--obs_cas_num', default=-1, type=int, help='Anchors observe cas num') parser.add_argument('--obs_cas_num', default=2000, type=int, help='Anchors observe cas num')
parser.add_argument('--obs_time', default=-1, type=int, help='Anchors observe time, default seeting is -1, meaning can observe all') parser.add_argument('--obs_time', default=-1, type=int, help='Anchors observe time, default seeting is -1, meaning can observe all')
parser.add_argument('--key_node_file', default='/home/zjy/project/CasDO/dataset/nh/key_node_10', type=str, help='key_node_json save file') parser.add_argument('--key_node_file', default='/home/zjy/project/CasDO/dataset/nh/key_node_10', type=str, help='key_node_json save file')
parser.add_argument('--anchor_screen', default=-0.4, type=float, help='anchor_screen')
return parser.parse_args() return parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':