55"""
66import csv
77import ast
8- from typing import NamedTuple , Dict
8+ import asyncio
9+ from typing import NamedTuple , Dict , List
10+ from dataclasses import dataclass
11+ from contextlib import asynccontextmanager
912
1013from ..repo import Repo
1114from .memory import MemorySource
# Register a CSV dialect that ignores whitespace immediately following a
# delimiter, so values like " foo" parse as "foo".
csv.register_dialect("strip", skipinitialspace=True)
1720
1821
@dataclass
class OpenCSVFile:
    """
    Shared state for a single open CSV file. Multiple sources (for example
    one per label) may use the same file; ``active`` reference-counts them.
    """

    # Repos pending write-back, keyed by label and then by src_url
    write_out: Dict
    # Number of sources currently using this file
    active: int
    # Guards updates to the reference count
    lock: asyncio.Lock
    # True when the key column was present in the file as read
    write_back_key: bool = True
    # True when the label column was present in the file as read
    write_back_label: bool = False

    async def inc(self):
        """Record one more source using this file."""
        async with self.lock:
            self.active += 1

    async def dec(self):
        """
        Record one source done with this file. Returns True when no sources
        remain, i.e. the caller is responsible for writing the file out.
        """
        async with self.lock:
            self.active -= 1
            return self.active < 1
39+
# Defaults used by CSVSourceConfig when not overridden via configuration:
# column holding each repo's unique key, label applied when no label column
# exists, and the name of the label column itself.
CSV_SOURCE_CONFIG_DEFAULT_KEY = "src_url"
CSV_SOURCE_CONFIG_DEFAULT_LABEL = "unlabeled"
CSV_SOURCE_CONFIG_DEFAULT_LABEL_COLUMN = "label"
44+
class CSVSourceConfig(FileSourceConfig, NamedTuple):
    """
    Configuration for a CSVSource.
    """

    # Path to the CSV file
    filename: str
    # When True, never write the file back out
    readonly: bool = False
    # Column containing each repo's unique key
    key: str = CSV_SOURCE_CONFIG_DEFAULT_KEY
    # Label assigned to repos when the file has no label column
    label: str = CSV_SOURCE_CONFIG_DEFAULT_LABEL
    # Column containing each repo's label
    label_column: str = CSV_SOURCE_CONFIG_DEFAULT_LABEL_COLUMN
2552
# NOTE: CSVSource coordinates shared per-file state across instances (see
# OpenCSVFile) — the bookkeeping below is intricate and due for a refactor.
2654@entry_point ("csv" )
2755class CSVSource (FileSource , MemorySource ):
2856 """
@@ -32,6 +60,29 @@ class CSVSource(FileSource, MemorySource):
3260 # Headers we've added to track data other than feature data for a repo
3361 CSV_HEADERS = ["prediction" , "confidence" ]
3462
63+ OPEN_CSV_FILES : Dict [str , OpenCSVFile ] = {}
64+ OPEN_CSV_FILES_LOCK : asyncio .Lock = asyncio .Lock ()
65+
66+ @asynccontextmanager
67+ async def _open_csv (self , fd = None ):
68+ async with self .OPEN_CSV_FILES_LOCK :
69+ if self .config .filename not in self .OPEN_CSV_FILES :
70+ self .logger .debug (f"{ self .config .filename } first open" )
71+ open_file = OpenCSVFile (
72+ active = 1 , lock = asyncio .Lock (), write_out = {}
73+ )
74+ self .OPEN_CSV_FILES [self .config .filename ] = open_file
75+ if fd is not None :
76+ await self .read_csv (fd , open_file )
77+ else :
78+ self .logger .debug (f"{ self .config .filename } already open" )
79+ await self .OPEN_CSV_FILES [self .config .filename ].inc ()
80+ yield self .OPEN_CSV_FILES [self .config .filename ]
81+
82+ async def _empty_file_init (self ):
83+ async with self ._open_csv ():
84+ return {}
85+
3586 @classmethod
3687 def args (cls , args , * above ) -> Dict [str , Arg ]:
3788 cls .config_set (args , above , "filename" , Arg ())
@@ -42,9 +93,18 @@ def args(cls, args, *above) -> Dict[str, Arg]:
4293 Arg (type = bool , action = "store_true" , default = False ),
4394 )
4495 cls .config_set (
45- args , above , "label" , Arg (type = str , default = "unlabeled" )
96+ args ,
97+ above ,
98+ "label" ,
99+ Arg (type = str , default = CSV_SOURCE_CONFIG_DEFAULT_LABEL ),
46100 )
47- cls .config_set (args , above , "key" , Arg (type = str , default = None ))
101+ cls .config_set (
102+ args ,
103+ above ,
104+ "labelcol" ,
105+ Arg (type = str , default = CSV_SOURCE_CONFIG_DEFAULT_LABEL_COLUMN ),
106+ )
107+ cls .config_set (args , above , "key" , Arg (type = str , default = "src_url" ))
48108 return args
49109
50110 @classmethod
@@ -54,38 +114,53 @@ def config(cls, config, *above):
54114 readonly = cls .config_get (config , above , "readonly" ),
55115 label = cls .config_get (config , above , "label" ),
56116 key = cls .config_get (config , above , "key" ),
117+ label_column = cls .config_get (config , above , "labelcol" ),
57118 )
58119
59- async def load_fd (self , fd ):
60- """
61- Parses a CSV stream into Repo instances
62- """
63- i = 0
64- self .mem = {}
65- for data in csv .DictReader (fd , dialect = "strip" ):
120+ async def read_csv (self , fd , open_file ):
121+ dict_reader = csv .DictReader (fd , dialect = "strip" )
122+ # Record what headers are present when the file was opened
123+ if not self .config .key in dict_reader .fieldnames :
124+ open_file .write_back_key = False
125+ if self .config .label_column in dict_reader .fieldnames :
126+ open_file .write_back_label = True
127+ # Store all the repos by their label in write_out
128+ open_file .write_out = {}
129+ # If there is no key track row index to be used as src_url by label
130+ index = {}
131+ for row in dict_reader :
132+ # Grab label from row
133+ label = row .get (self .config .label_column , self .config .label )
134+ if self .config .label_column in row :
135+ del row [self .config .label_column ]
136+ index .setdefault (label , 0 )
137+ # Grab src_url from row
138+ src_url = row .get (self .config .key , index [label ])
139+ if self .config .key in row :
140+ del row [self .config .key ]
141+ else :
142+ index [label ] += 1
66143 # Repo data we are going to parse from this row (must include
67144 # features).
68- repo_data = {"features" : {} }
145+ repo_data = {}
69146 # Parse headers we as the CSV source added
70147 csv_meta = {}
71148 for header in self .CSV_HEADERS :
72- if not data .get (header ) is None and data [header ] != "" :
73- csv_meta [header ] = data [header ]
149+ value = row .get (header , None )
150+ if value is not None and value != "" :
151+ csv_meta [header ] = row [header ]
74152 # Remove from feature data
75- del data [header ]
76- # Parse feature data
77- for key , value in data .items ():
153+ del row [header ]
154+ # Set the features
155+ features = {}
156+ for key , value in row .items ():
78157 if value != "" :
79158 try :
80- repo_data [ " features" ] [key ] = ast .literal_eval (value )
159+ features [key ] = ast .literal_eval (value )
81160 except (SyntaxError , ValueError ):
82- repo_data ["features" ][key ] = value
83- if self .config .key is not None and self .config .key == key :
84- src_url = value
85- if self .config .key is None :
86- src_url = str (i )
87- i += 1
88- # Correct types and structure of repo data from csv_meta
161+ features [key ] = value
162+ if features :
163+ repo_data ["features" ] = features
89164 if "prediction" in csv_meta and "confidence" in csv_meta :
90165 repo_data .update (
91166 {
@@ -95,32 +170,67 @@ async def load_fd(self, fd):
95170 }
96171 }
97172 )
98- repo = Repo (src_url , data = repo_data )
99- self .mem [repo .src_url ] = repo
173+ # If there was no data in the row, skip it
174+ if not repo_data and src_url == str (index [label ] - 1 ):
175+ continue
176+ # Add the repo to our internal memory representation
177+ open_file .write_out .setdefault (label , {})
178+ open_file .write_out [label ][src_url ] = Repo (src_url , data = repo_data )
179+
180+ async def load_fd (self , fd ):
181+ """
182+ Parses a CSV stream into Repo instances
183+ """
184+ async with self ._open_csv (fd ) as open_file :
185+ self .mem = open_file .write_out .get (self .config .label , {})
100186 self .logger .debug ("%r loaded %d records" , self , len (self .mem ))
101187
102188 async def dump_fd (self , fd ):
103189 """
104190 Dumps data into a CSV stream
105191 """
106- # Sample some headers without iterating all the way through
107- fieldnames = []
108- for repo in self .mem .values ():
109- fieldnames = list (repo .data .features .keys ())
110- break
111- # Add our headers
112- fieldnames += self .CSV_HEADERS
113- # Write out the file
114- writer = csv .DictWriter (fd , fieldnames = fieldnames )
115- writer .writeheader ()
116- # Write out rows in order
117- for repo in self .mem .values ():
118- repo_data = repo .dict ()
119- row = {}
120- for key , value in repo_data ["features" ].items ():
121- row [key ] = value
122- if "prediction" in repo_data :
123- row ["prediction" ] = repo_data ["prediction" ]["value" ]
124- row ["confidence" ] = repo_data ["prediction" ]["confidence" ]
125- writer .writerow (row )
192+ async with self .OPEN_CSV_FILES_LOCK :
193+ open_file = self .OPEN_CSV_FILES [self .config .filename ]
194+ open_file .write_out .setdefault (self .config .label , {})
195+ open_file .write_out [self .config .label ].update (self .mem )
196+ # Bail if not last open source for this file
197+ if not (await open_file .dec ()):
198+ return
199+ # Add our headers
200+ fieldnames = (
201+ [] if not open_file .write_back_key else [self .config .key ]
202+ )
203+ fieldnames .append (self .config .label_column )
204+ # Get all the feature names
205+ feature_fieldnames = set ()
206+ for label , repos in open_file .write_out .items ():
207+ for repo in repos .values ():
208+ feature_fieldnames |= set (repo .data .features .keys ())
209+ fieldnames += list (feature_fieldnames )
210+ fieldnames += self .CSV_HEADERS
211+ self .logger .debug (f"fieldnames: { fieldnames } " )
212+ # Write out the file
213+ writer = csv .DictWriter (fd , fieldnames = fieldnames )
214+ writer .writeheader ()
215+ for label , repos in open_file .write_out .items ():
216+ for repo in repos .values ():
217+ repo_data = repo .dict ()
218+ row = {name : "" for name in fieldnames }
219+ # Always write the label
220+ row [self .config .label_column ] = label
221+ # Write the key if it existed
222+ if open_file .write_back_key :
223+ row [self .config .key ] = repo .src_url
224+ # Write the features
225+ for key , value in repo_data .get ("features" , {}).items ():
226+ row [key ] = value
227+ # Write the prediction
228+ if "prediction" in repo_data :
229+ row ["prediction" ] = repo_data ["prediction" ]["value" ]
230+ row ["confidence" ] = repo_data ["prediction" ][
231+ "confidence"
232+ ]
233+ writer .writerow (row )
234+ del self .OPEN_CSV_FILES [self .config .filename ]
235+ self .logger .debug (f"{ self .config .filename } written" )
126236 self .logger .debug ("%r saved %d records" , self , len (self .mem ))
0 commit comments