78 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			78 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import argparse
 | 
						|
from collections import defaultdict
 | 
						|
import json
 | 
						|
import os
 | 
						|
from model import anchor_finder
 | 
						|
 | 
						|
def handle_cas(args):
    """Parse a cascade file into per-cascade adoption events and save them as JSON.

    Each non-blank line of ``args.input_file`` is split on commas. The first
    field is the cascade id. Every following field is expected to look like
    ``<path>:<time>``, where ``<path>`` is a ``/``-separated chain of node ids.
    Only the last node of each path is recorded, together with its time
    string (kept as a string, not converted). Fields without a ``:`` are
    skipped.

    The mapping is written to ``<args.result_path>/event_dict.json``; the
    directory is created if missing. Saving is best-effort: write or
    serialization failures are printed rather than raised.

    Args:
        args: Namespace providing ``input_file`` (path of the cascade file)
            and ``result_path`` (output directory).

    Returns:
        ``defaultdict(list)`` mapping cascade id -> list of
        ``{'user_id': str, 'act_time': str}`` dicts, in file order.
    """
    filename, result_path = args.input_file, args.result_path
    print("Handle cascade begin!")

    # Each cascade accumulates an ordered list of events; a dict default
    # (as before) has no .append and crashed on the first event.
    event_dict = defaultdict(list)
    cascades_total = 0

    with open(filename, encoding='utf-8') as file:
        for line in file:
            # Strip the line ending so it does not leak into the last act_time.
            line = line.strip()
            if not line:
                continue
            cascades_total += 1

            cascade_id, *paths = line.split(',')
            for p in paths:
                # observed adoption/participant: "<node>/<node>/...:<time>"
                n_t = p.split(':')
                if len(n_t) < 2:
                    continue
                node_id = n_t[0].split('/')[-1]
                time_now = n_t[1]
                event_dict[cascade_id].append({
                    'user_id': node_id,
                    'act_time': time_now,
                })
            print(f"cascade_id:{cascade_id}")

    print(f"Handle cascade file finish! Cascades read: {cascades_total}")

    os.makedirs(result_path, exist_ok=True)
    json_file = os.path.join(result_path, 'event_dict.json')
    # Best-effort save: report failures but still return the parsed events.
    # (Messages: saved successfully / dict may contain non-JSON-serializable
    # types / cannot open or write the file / unknown error while saving.)
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(event_dict, f, indent=4, ensure_ascii=False)
        print(f"user_event_dict已成功保存到 JSON 文件: {json_file}")
    except TypeError as e:
        print(f"错误: 字典中可能包含无法序列化为 JSON 的数据类型。错误: {e}")
    except IOError:
        print(f"错误: 无法打开或写入文件 '{json_file}'。")
    except Exception as e:
        print(f"保存 JSON 文件时发生未知错误: {e}")

    return event_dict
 | 
						|
 | 
						|
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(description='Parameters')
    parser.add_argument('--input_file', default='./dataset_for_anchor.txt', type=str, help='Cascade file')
    parser.add_argument('--result_path', default='./result/', type=str, help='result save path')
    parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchors proportion')
    parser.add_argument('--anchor_num', default=1000, type=int, help='Anchors num')
    parser.add_argument('--obs_cas_num', default=-1, type=int, help='Anchors observe cas num')
    parser.add_argument('--obs_time', default=-1, type=int,
                        help='Anchors observe time, default setting is -1, meaning can observe all')
    # NOTE(review): machine-specific absolute default path; pass --key_node_file on other hosts.
    parser.add_argument('--key_node_file', default='/home/zjy/project/CasDO/dataset/nh/key_node_10', type=str,
                        help='key_node_json save file')
    parser.add_argument('--anchor_screen', default=-0.8, type=float, help='anchor_screen')
    parser.add_argument('--cas_min_length', default=20, type=int, help='Cas min length')
    # NOTE(review): t_max and t_min share the same default value — confirm this is intended.
    parser.add_argument('--t_max', default=1751299200, type=int, help='Cas max time (t_max)')
    parser.add_argument('--t_min', default=1751299200, type=int, help='Cas min time (t_min)')
    return parser.parse_args(argv)
 | 
						|
 | 
						|
if __name__ == '__main__':
    args = parse_args()
    result_path, anchor_budget, obs_time, anchor_num = (
        args.result_path, args.anchor_budget, args.obs_time, args.anchor_num)

    # Parse the cascade file into {cascade_id: [{'user_id', 'act_time'}, ...]}.
    event_dict = handle_cas(args)
    print(f"anchor_num is set to {anchor_num}")

    # =========== Run the m-point (anchor) mining algorithm ===========
    # The original author's note says this returns the m points and the number
    # of perceived cascades; it unpacks three values here — confirm their exact
    # meaning against anchor_finder.generate_anchors.
    A, all_cas, user_event_dict = anchor_finder.generate_anchors(
        event_dict, anchor_budget, obs_time, result_path, anchor_num)