renamed: dataset_for_anchor.txt -> data/dataset_for_anchor.txt modified: main.py modified: model/anchor_finder.py
		
			
				
	
	
		
			74 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			74 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import argparse
 | 
						||
from collections import defaultdict
 | 
						||
import json
 | 
						||
import os
 | 
						||
from model import anchor_finder
 | 
						||
 | 
						||
def _parse_cascade_line(line):
    """Split one cascade record into its id and its list of events.

    The record is comma-separated: the first field is the cascade id and
    every following field has the form ``<node>/<node>/...:<time>``.  The
    last node of the path is taken as the acting user.  Fields without a
    ``:`` are ignored.

    Returns:
        A ``(cascade_id, events)`` tuple where ``events`` is a list of
        ``{'user_id': str, 'act_time': str}`` dicts in file order.
    """
    cascade_id, *paths = line.split(',')
    events = []
    for path in paths:
        # observed adoption/participant
        n_t = path.split(':')
        if len(n_t) < 2:
            continue
        events.append({
            'user_id': n_t[0].split('/')[-1],
            'act_time': n_t[1],
        })
    return cascade_id, events


def _save_event_dict(event_dict, json_file):
    """Write ``event_dict`` to ``json_file`` as UTF-8 JSON (best effort).

    Failures are reported on stdout and otherwise ignored so that the
    caller can still use the in-memory dictionary.
    """
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(event_dict, f, indent=4, ensure_ascii=False)
        print(f"user_event_dict已成功保存到 JSON 文件: {json_file}")
    except TypeError as e:
        print(f"错误: 字典中可能包含无法序列化为 JSON 的数据类型。错误: {e}")
    except IOError:
        print(f"错误: 无法打开或写入文件 '{json_file}'。")
    except Exception as e:
        print(f"保存 JSON 文件时发生未知错误: {e}")


def handle_cas(args):
    """Build the cascade -> events mapping from the cascade file.

    Reads ``args.input_file`` line by line, parses each non-blank line as a
    cascade record, and groups the resulting events by cascade id.  The
    mapping is also dumped to ``<args.result_path>/event_dict.json``.

    Args:
        args: Object exposing ``input_file`` (path of the cascade file),
            ``result_path`` (output directory, created if missing) and
            ``obs_cas_num`` (maximum number of cascades to read; values
            ``<= 0`` mean no limit).

    Returns:
        A ``defaultdict(list)`` mapping each cascade id to a list of
        ``{'user_id': str, 'act_time': str}`` dicts.  Cascades without any
        valid event get no entry.

    Raises:
        OSError: If the input file cannot be opened or the result
            directory cannot be created.
    """
    filename, result_path = args.input_file, args.result_path
    print("Handle cascade begin!")
    event_dict = defaultdict(list)
    limit = args.obs_cas_num
    cascades_total = 0
    # Read as UTF-8 so input decoding matches the UTF-8 JSON written below
    # instead of depending on the platform's default encoding.
    with open(filename, encoding='utf-8') as file:
        for raw_line in file:
            # Strip the line terminator; otherwise the last event of every
            # cascade would carry a trailing "\n" in its 'act_time'.
            line = raw_line.strip()
            if not line:
                # Blank lines are not cascades and do not count towards
                # the observation limit.
                continue
            if 0 < limit <= cascades_total:
                break
            cascades_total += 1

            cascade_id, events = _parse_cascade_line(line)
            if events:
                # Only touch the defaultdict when there is something to
                # store, so empty cascades do not create empty entries.
                event_dict[cascade_id].extend(events)
            print(f"cascade_id:{cascade_id}")

    print("Handle cascade file finish!")
    os.makedirs(result_path, exist_ok=True)
    _save_event_dict(event_dict, os.path.join(result_path, 'event_dict.json'))

    return event_dict
 | 
						||
 | 
						||
def parse_args(argv=None):
    """Parse the command-line options of the anchor-finding demo.

    Args:
        argv: Optional list of argument strings.  When ``None`` (the
            default) the arguments are taken from ``sys.argv[1:]``, which
            is what the previous zero-argument version did.

    Returns:
        An ``argparse.Namespace`` with the attributes ``input_file``,
        ``result_path``, ``anchor_budget``, ``anchor_num``, ``obs_cas_num``,
        ``obs_time`` and ``anchor_screen``.
    """
    parser = argparse.ArgumentParser(description='Parameters')
    parser.add_argument('--input_file', default='./data/dataset_for_anchor.txt', type=str, help='Cascade file')
    parser.add_argument('--result_path', default='./result/', type=str, help='result save path')
    parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchors proportion')
    parser.add_argument('--anchor_num', default=10, type=int, help='Anchors num')
    parser.add_argument('--obs_cas_num', default=200, type=int, help='Anchors observe cas num')
    parser.add_argument('--obs_time', default=-1, type=int, help='Anchors observe time, default setting is -1, meaning can observe all')
    # argparse does not run `type` on non-string defaults, so the default is
    # written as a float to match what a user-supplied value would produce.
    parser.add_argument('--anchor_screen', default=1.0, type=float, help='anchor_screen')
    return parser.parse_args(argv)
 | 
						||
 | 
						||
if __name__ == '__main__':
    args = parse_args()

    # This demo shows one example of producing event_dict; in production
    # the event_dict has to be built by the caller.
    event_dict = handle_cas(args)

    # =========== Run the m-point mining algorithm; it returns the m-points
    # and the number of cascades perceived ===========
    A, all_cas, user_event_dict = anchor_finder.generate_anchors(
        event_dict,
        args.anchor_budget,
        args.anchor_screen,
        args.obs_time,
        args.result_path,
        args.anchor_num,
    )