Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions tn/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
from pynini.lib import byte, utf8
from pynini.lib.pynutil import delete, insert

logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s WETEXT %(levelname)s %(message)s')


class Processor:

Expand Down Expand Up @@ -78,6 +75,13 @@ def build_verbalizer(self):
self.verbalizer = self.delete_tokens(verbalizer)

def build_fst(self, prefix, cache_dir, overwrite_cache):
logger = logging.getLogger('wetext-{}'.format(self.name))
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
fmt = logging.Formatter('%(asctime)s WETEXT %(levelname)s %(message)s')
handler.setFormatter(fmt)
logger.addHandler(handler)

os.makedirs(cache_dir, exist_ok=True)
tagger_name = '{}_tagger.fst'.format(prefix)
verbalizer_name = '{}_verbalizer.fst'.format(prefix)
Expand All @@ -88,20 +92,20 @@ def build_fst(self, prefix, cache_dir, overwrite_cache):
exists = os.path.exists(tagger_path) and os.path.exists(
verbalizer_path)
if exists and not overwrite_cache:
logging.info("found existing fst: {}".format(tagger_path))
logging.info(" {}".format(verbalizer_path))
logging.info("skip building fst for {} ...".format(self.name))
logger.info("found existing fst: {}".format(tagger_path))
logger.info(" {}".format(verbalizer_path))
logger.info("skip building fst for {} ...".format(self.name))
self.tagger = Fst.read(tagger_path).optimize()
self.verbalizer = Fst.read(verbalizer_path).optimize()
else:
logging.info("building fst for {} ...".format(self.name))
logger.info("building fst for {} ...".format(self.name))
self.build_tagger()
self.build_verbalizer()
self.tagger.optimize().write(tagger_path)
self.verbalizer.optimize().write(verbalizer_path)
logging.info("done")
logging.info("fst path: {}".format(tagger_path))
logging.info(" {}".format(verbalizer_path))
logger.info("done")
logger.info("fst path: {}".format(tagger_path))
logger.info(" {}".format(verbalizer_path))

def tag(self, input):
if len(input) == 0:
Expand Down