new file: main.py
new file: model/anchor_finder.py
This commit is contained in:
		
							parent
							
								
									75d72a8c6c
								
							
						
					
					
						commit
						1f72b0e6c9
					
				
							
								
								
									
										78
									
								
								main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								main.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,78 @@
 | 
			
		|||
import argparse
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
from model import anchor_finder
 | 
			
		||||
 | 
			
		||||
def handle_cas(args):
    """Parse a cascade file into per-cascade adoption events and dump them as JSON.

    Each non-empty input line is split on ``,``: the first field is the
    cascade id and every following field is expected to look like
    ``<node>/<node>/...:<time>``.  For each such field the last node of the
    ``/``-separated path and the text after the first ``:`` are recorded.
    Fields without a ``:`` are ignored.

    Parameters:
    - args: object exposing ``input_file`` (path of the cascade file) and
      ``result_path`` (directory that receives ``event_dict.json``).

    Returns:
    - event_dict: ``defaultdict(list)`` mapping cascade id to a list of
      ``{'user_id': str, 'act_time': str}`` dicts, in file order.

    Saving the JSON copy is best-effort: failures are reported on stdout and
    the parsed mapping is still returned.
    """
    filename, result_path = args.input_file, args.result_path
    print("Handle cascade begin!")

    # Values must be lists: every adoption is appended to its cascade.
    event_dict = defaultdict(list)

    with open(filename, encoding='utf-8') as file:
        for raw_line in file:
            # Strip the trailing newline so it does not end up inside the
            # last timestamp of the line; skip blank lines entirely.
            line = raw_line.strip()
            if not line:
                continue

            cascade_id, *paths = line.split(',')

            for p in paths:
                # observed adoption/participant
                n_t = p.split(':')
                if len(n_t) < 2:
                    continue
                node_id = n_t[0].split('/')[-1]
                event_dict[cascade_id].append({
                    'user_id': node_id,
                    'act_time': n_t[1].strip(),
                })
            print(f"cascade_id:{cascade_id}")

    print(f"Handle cascade file finish!")
    os.makedirs(result_path, exist_ok=True)
    json_file = os.path.join(result_path, f'event_dict.json')
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(event_dict, f, indent=4, ensure_ascii=False)
        print(f"user_event_dict已成功保存到 JSON 文件: {json_file}")
    except TypeError as e:
        print(f"错误: 字典中可能包含无法序列化为 JSON 的数据类型。错误: {e}")
    except IOError:
        print(f"错误: 无法打开或写入文件 '{json_file}'。")
    except Exception as e:
        print(f"保存 JSON 文件时发生未知错误: {e}")

    return event_dict
 | 
			
		||||
 | 
			
		||||
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Parameters:
    - argv: optional list of argument strings; when ``None`` argparse falls
      back to ``sys.argv[1:]``, which is the behaviour of a plain
      ``parse_args()`` call.

    Returns:
    - argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser(description='Parameters')
    parser.add_argument('--input_file', default='./dataset_for_anchor.txt', type=str, help='Cascade file')
    parser.add_argument('--result_path', default='./result/', type=str, help='result save path')
    parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchors proportion')
    parser.add_argument('--anchor_num', default=1000, type=int, help='Anchors num')
    parser.add_argument('--obs_cas_num', default=-1, type=int, help='Anchors observe cas num')
    parser.add_argument('--obs_time', default=-1, type=int,
                        help='Anchors observe time, default setting is -1, meaning can observe all')
    # NOTE(review): the default is a machine-specific absolute path; kept
    # unchanged so existing invocations behave the same.
    parser.add_argument('--key_node_file', default='/home/zjy/project/CasDO/dataset/nh/key_node_10', type=str,
                        help='key_node_json save file')
    parser.add_argument('--anchor_screen', default=-0.8, type=float, help='anchor_screen')
    parser.add_argument('--cas_min_length', default=20, type=int, help='Cas min length')
    # NOTE(review): t_max and t_min share the same default, so with the
    # defaults only cascades starting exactly at that time pass the window
    # check in the anchor finder - confirm this is intended.
    parser.add_argument('--t_max', default=1751299200, type=int,
                        help='Latest allowed start time of a cascade')
    parser.add_argument('--t_min', default=1751299200, type=int,
                        help='Earliest allowed start time of a cascade')
    return parser.parse_args(argv)
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
    args = parse_args()

    # model/anchor_finder.py looks up `args.anchor_screen`, `args.t_min`,
    # `args.t_max` and `args.cas_min_length` as module-level names but never
    # defines `args` itself; publish the parsed namespace on that module so
    # those lookups resolve instead of raising NameError.
    anchor_finder.args = args

    event_dict = handle_cas(args)
    print(f"anchor_num is set to {args.anchor_num}")

    # =========== Run the anchor mining algorithm; it returns the anchors and
    # the number of cascades they perceive =================
    A, all_cas, user_event_dict = anchor_finder.generate_anchors(
        event_dict,
        args.anchor_budget,
        args.obs_time,
        args.result_path,
        args.anchor_num,
    )
 | 
			
		||||
							
								
								
									
										186
									
								
								model/anchor_finder.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								model/anchor_finder.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,186 @@
 | 
			
		|||
from collections import defaultdict
 | 
			
		||||
import numpy as np
 | 
			
		||||
import copy
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
import numpy as np
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_max_indices_numpy(L_dict):
    """Return every key of ``L_dict`` that holds the largest value.

    Parameters:
    - L_dict: non-empty mapping of key -> comparable value.

    Returns:
    - (keys, max_value): ``keys`` is a NumPy array with the keys whose value
      equals the maximum, in dictionary order; ``max_value`` is that maximum.

    Raises:
    - ValueError: if ``L_dict`` is empty (raised by ``max``).
    """
    top_value = max(L_dict.values())
    key_array = np.array(list(L_dict.keys()))
    value_array = np.array(list(L_dict.values()))

    # Positions of all entries tied for the maximum.
    hit_positions = np.flatnonzero(value_array == top_value)
    return key_array[hit_positions], top_value
 | 
			
		||||
 | 
			
		||||
def R(event_dict):
    """Return the mean of the negated values of ``event_dict``.

    Parameters:
    - event_dict: non-empty mapping whose values are numbers.

    Returns:
    - ``-(v1 + v2 + ...) / len(event_dict)``, accumulated value by value.

    Raises:
    - ZeroDivisionError: if ``event_dict`` is empty.
    """
    negated_total = 0
    for value in event_dict.values():
        negated_total -= value
    return negated_total / len(event_dict)
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b, anchor_screen=None):
    """
    Greedily pick social sensors (anchors).

    Each round considers the users tied for the highest remaining event
    count, scores each one with ``R`` on its remaining events and selects the
    best-scoring user (ties resolved in ``L_dict`` order).  The selected
    user's cascades are then removed from every participant.  Once the number
    of covered cascades reaches ``len(event_user_dict)`` the working state is
    rebuilt from the inputs (with already selected users zeroed) so further
    sensors can be added.

    Parameters:
    - R: reward function, called as ``R(events)`` with a user's remaining
      ``event -> active_time`` mapping; returns a numeric value
    - L_dict:  dict of node -> integer (event num)
    - user_event_dict: user -> event -> active_time
    - event_user_dict: event -> user_list
    - b: budget (numeric); selection stops once ``len(A) >= b``
    - anchor_screen: minimum reward a candidate must reach.  When ``None``
      the value is taken from a module-level ``args.anchor_screen`` if one
      has been set; otherwise no screening is applied.

    Returns:
    - A: list of selected social sensors, in selection order
    - all_cas: number of cascades covered since the last state rebuild

    The input dictionaries are not modified.
    """
    if anchor_screen is None:
        # The original code read a module-level `args`; honour it when the
        # caller has provided one, otherwise accept every candidate.
        anchor_screen = getattr(globals().get("args"), "anchor_screen", float("-inf"))

    print("Select social sensors begin!")
    remaining_events = copy.deepcopy(user_event_dict)
    remaining_counts = copy.deepcopy(L_dict)

    total_users = len(remaining_counts)
    A = []
    all_cas = 0
    # A user is never selected twice (its count is zero afterwards), so
    # "some user is not selected yet" is the same as len(A) < total_users.
    while len(A) < total_users and len(A) < b:
        max_count = max(remaining_counts.values())
        if max_count <= 0:
            print("no extra node in cas!!!!!")
            break

        # Candidates in dictionary order keeps tie-breaking reproducible.
        candidates = [u for u, c in remaining_counts.items() if c == max_count]
        rewards = {u: R(remaining_events[u]) for u in candidates}
        s_star = max(rewards, key=rewards.get)

        if rewards[s_star] < anchor_screen:
            # The best candidate fails the screen, so every tied candidate
            # does; discard them together instead of one per iteration.
            for u in candidates:
                remaining_counts[u] = 0
            continue

        A.append(s_star)
        c_star = remaining_events[s_star]
        all_cas += len(c_star)
        # Snapshot the keys: c_star itself shrinks when s_star is visited
        # as a participant of its own cascades.
        for cas_id in list(c_star):
            for v in event_user_dict.get(cas_id, ()):
                if v in remaining_counts:
                    remaining_counts[v] -= 1
                participant_events = remaining_events.get(v)
                if participant_events is not None:
                    participant_events.pop(cas_id, None)

        print(f"Add a social sensor, sensors num is {len(A)}")
        print(f"Anchor id: {s_star}")
        print(f"all_cas is {all_cas}")
        print(f"TAR_set size: {len(candidates)}")

        if all_cas >= len(event_user_dict):
            # Every cascade is covered: start over from the full data so
            # additional sensors can still be chosen, excluding those picked.
            remaining_events = copy.deepcopy(user_event_dict)
            remaining_counts = copy.deepcopy(L_dict)
            for s in A:
                remaining_counts[s] = 0
            all_cas = 0
            print(f"All cas has been perception, add extra node")

    print(f"Select social sensors finish! Get social sensors, num: {len(A)}")
    return A, all_cas
 | 
			
		||||
 | 
			
		||||
def handle_event(event_dict, obs_time, cas_min_length=None, t_min=None, t_max=None):
    """Turn raw cascade events into the lookup tables used for anchor selection.

    For every cascade the earliest activation time of each user is kept
    (ignoring events later than ``obs_time`` unless ``obs_time`` is -1) and
    normalised to ``(time - first) / (last - first)``, where ``first`` and
    ``last`` are the smallest and largest timestamps among all events of the
    cascade, including those beyond ``obs_time``.

    Parameters:
    - event_dict: mapping of cascade id -> list of dicts with the keys
      ``'user_id'`` and ``'act_time'`` (``act_time`` must be int-convertible)
    - obs_time: observation cut-off; -1 disables the cut-off
    - cas_min_length: minimum number of distinct observed users a cascade
      needs to be kept
    - t_min / t_max: inclusive bounds on the cascade's first timestamp

    ``cas_min_length``, ``t_min`` and ``t_max`` default to ``None``; in that
    case each is read from a module-level ``args`` namespace if one has been
    set, and otherwise the corresponding filter is disabled.

    Returns:
    - L_dict: user -> number of kept cascades the user takes part in
    - user_event_dict: user -> cascade id -> normalised activation time
    - event_user_dict: cascade id -> list of participating users

    Cascades whose events all share one timestamp cannot be normalised and
    are skipped in both tables so the two stay consistent.
    """
    # The original code read a module-level `args`; honour it when present.
    cfg = globals().get("args")
    if cas_min_length is None:
        cas_min_length = getattr(cfg, "cas_min_length", 0)
    if t_min is None:
        t_min = getattr(cfg, "t_min", None)
    if t_max is None:
        t_max = getattr(cfg, "t_max", None)

    user_event_dict = defaultdict(dict)
    event_user_dict = defaultdict(list)

    for cascade_id, event_list in event_dict.items():
        activation_times = {}
        first_time = float('inf')
        last_time = 0
        for u_t_dict in event_list:
            time_now = int(u_t_dict['act_time'])
            # The cascade's time span covers every event, observed or not.
            last_time = max(last_time, time_now)
            first_time = min(first_time, time_now)

            if obs_time != -1 and time_now > obs_time:
                continue

            node_id = u_t_dict['user_id']
            earliest = activation_times.get(node_id)
            if earliest is None or time_now < earliest:
                activation_times[node_id] = time_now

        print(f"cascade_id:{cascade_id}, t_min:{first_time}")

        if len(activation_times) < cas_min_length:
            print(f"Too short Cas! Cas_id:{cascade_id}")
            continue
        too_late = t_max is not None and first_time > t_max
        too_early = t_min is not None and first_time < t_min
        if too_late or too_early:
            print(f"Cas is not in proper time! Cas_id:{cascade_id}")
            continue
        if last_time <= first_time:
            # Zero time span (or no events at all): nothing to normalise.
            continue

        span = last_time - first_time
        for node_id, act_time in activation_times.items():
            event_user_dict[cascade_id].append(node_id)
            user_event_dict[node_id][cascade_id] = (act_time - first_time) / span

    L_dict = {user: len(events) for user, events in user_event_dict.items()}
    return L_dict, user_event_dict, event_user_dict
 | 
			
		||||
 | 
			
		||||
def generate_anchors(event_dict, anchor_budget, obs_time, result_path, anchor_num=-1):
    """Select anchors from the cascade events and store the results on disk.

    Steps: build the user/event tables with ``handle_event``, dump
    ``user_event_dict`` to ``<result_path>/user_event.json`` (best-effort;
    failures are only reported on stdout), run ``select_social_sensors`` and
    write one selected anchor per line to
    ``<result_path>/anchors_<anchor_num>.txt``.

    Parameters:
    - event_dict: cascade id -> list of event dicts, as consumed by
      ``handle_event``
    - anchor_budget: fraction of users to select when ``anchor_num`` is -1
    - obs_time: observation cut-off forwarded to ``handle_event``
    - result_path: output directory, created if missing
    - anchor_num: number of anchors to select; -1 derives it from
      ``anchor_budget``

    Returns:
    - A: list of selected anchors
    - all_cas: cascade counter returned by ``select_social_sensors``
    - user_event_dict: user -> cascade id -> normalised activation time
    """
    L_dict, user_event_dict, event_user_dict = handle_event(event_dict, obs_time)

    print(f"Handle cascade file finish! Users num is {len(L_dict)}")
    os.makedirs(result_path, exist_ok=True)
    json_file = os.path.join(result_path, f'user_event.json')
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(user_event_dict, f, indent=4, ensure_ascii=False)
        print(f"user_event_dict已成功保存到 JSON 文件: {json_file}")
    except TypeError as e:
        print(f"错误: 字典中可能包含无法序列化为 JSON 的数据类型。错误: {e}")
    except IOError:
        print(f"错误: 无法打开或写入文件 '{json_file}'。")
    except Exception as e:
        print(f"保存 JSON 文件时发生未知错误: {e}")

    if anchor_num == -1:
        # Round up so the budget is a whole number: the selection loop runs
        # while len(A) < budget, so the ceiling selects the same number of
        # anchors as the fractional value while keeping the file name clean.
        anchor_num = int(np.ceil(len(L_dict) * anchor_budget))

    A, all_cas = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_num)

    output_file = os.path.join(result_path, f'anchors_{anchor_num}.txt')
    with open(output_file, 'w', encoding='utf-8') as file:
        file.writelines(f"{item}\n" for item in A)

    return A, all_cas, user_event_dict
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user