2626from sqlalchemy .orm import backref , foreign , relationship
2727from sqlalchemy .orm .session import make_transient
2828
29- from airflow .api_internal .internal_api_call import internal_api_call
3029from airflow .configuration import conf
3130from airflow .exceptions import AirflowException
3231from airflow .executors .executor_loader import ExecutorLoader
3332from airflow .listeners .listener import get_listener_manager
3433from airflow .models .base import ID_LEN , Base
35- from airflow .serialization .pydantic .job import JobPydantic
3634from airflow .stats import Stats
3735from airflow .traces .tracer import Trace , add_span
3836from airflow .utils import timezone
3937from airflow .utils .helpers import convert_camel_to_snake
4038from airflow .utils .log .logging_mixin import LoggingMixin
4139from airflow .utils .net import get_hostname
4240from airflow .utils .platform import getuser
43- from airflow .utils .retries import retry_db_transaction
44- from airflow .utils .session import NEW_SESSION , provide_session
41+ from airflow .utils .session import NEW_SESSION , create_session , provide_session
4542from airflow .utils .sqlalchemy import UtcDateTime
4643from airflow .utils .state import JobState
4744
@@ -168,7 +165,10 @@ def kill(self, session: Session = NEW_SESSION) -> NoReturn:
168165 except Exception as e :
169166 self .log .error ("on_kill() method failed: %s" , e )
170167
171- Job ._kill (job_id = self .id , session = session )
168+ job = session .scalar (select (Job ).where (Job .id == self .id , session = session ).limit (1 ))
169+ job .end_date = timezone .utcnow ()
170+ session .merge (job )
171+ session .commit ()
172172 raise AirflowException ("Job shut down externally." )
173173
174174 def on_kill (self ):
@@ -201,7 +201,7 @@ def heartbeat(
201201 try :
202202 span .set_attribute ("heartbeat" , str (self .latest_heartbeat ))
203203 # This will cause it to load from the db
204- self . _merge_from ( Job . _fetch_from_db ( self , session ) )
204+ session . merge ( self )
205205 previous_heartbeat = self .latest_heartbeat
206206
207207 if self .state == JobState .RESTARTING :
@@ -217,17 +217,19 @@ def heartbeat(
217217 if span .is_recording ():
218218 span .add_event (name = "sleep" , attributes = {"sleep_for" : sleep_for })
219219 sleep (sleep_for )
220-
221- job = Job ._update_heartbeat (job = self , session = session )
222- self ._merge_from (job )
223- time_since_last_heartbeat = (timezone .utcnow () - previous_heartbeat ).total_seconds ()
224- health_check_threshold_value = health_check_threshold (self .job_type , self .heartrate )
225- if time_since_last_heartbeat > health_check_threshold_value :
226- self .log .info ("Heartbeat recovered after %.2f seconds" , time_since_last_heartbeat )
227- # At this point, the DB has updated.
228- previous_heartbeat = self .latest_heartbeat
229-
230- heartbeat_callback (session )
220+ # Update last heartbeat time
221+ with create_session () as session :
222+ # Make the session aware of this object
223+ session .merge (self )
224+ self .latest_heartbeat = timezone .utcnow ()
225+ session .commit ()
226+ time_since_last_heartbeat = (timezone .utcnow () - previous_heartbeat ).total_seconds ()
227+ health_check_threshold_value = health_check_threshold (self .job_type , self .heartrate )
228+ if time_since_last_heartbeat > health_check_threshold_value :
229+ self .log .info ("Heartbeat recovered after %.2f seconds" , time_since_last_heartbeat )
230+ # At this point, the DB has updated.
231+ previous_heartbeat = self .latest_heartbeat
232+ heartbeat_callback (session )
231233 self .log .debug ("[heartbeat]" )
232234 self .heartbeat_failed = False
233235 except OperationalError :
@@ -260,36 +262,23 @@ def prepare_for_execution(self, session: Session = NEW_SESSION):
260262 Stats .incr (self .__class__ .__name__ .lower () + "_start" , 1 , 1 )
261263 self .state = JobState .RUNNING
262264 self .start_date = timezone .utcnow ()
263- self ._merge_from (Job ._add_to_db (job = self , session = session ))
265+ session .add (self )
266+ session .commit ()
264267 make_transient (self )
265268
    @provide_session
    def complete_execution(self, session: Session = NEW_SESSION):
        """Mark this job as finished: notify listeners, stamp end_date, and persist it."""
        # Give registered listeners a chance to react before the component stops.
        get_listener_manager().hook.before_stopping(component=self)
        self.end_date = timezone.utcnow()
        # Merge so the end_date update lands on the existing DB row for this job.
        session.merge(self)
        session.commit()
        # Metric name is the lowercased class name plus "_end", e.g. "schedulerjob_end".
        Stats.incr(self.__class__.__name__.lower() + "_end", 1, 1)
272276
    @provide_session
    def most_recent_job(self, session: Session = NEW_SESSION) -> Job | None:
        """Return the most recent job of this type, if any, based on last heartbeat received."""
        # Delegates to the module-level most_recent_job() helper, filtering by
        # this instance's job_type.
        return most_recent_job(self.job_type, session=session)
277281
278- def _merge_from (self , job : Job | JobPydantic | None ):
279- if job is None :
280- self .log .error ("Job is empty: %s" , self .id )
281- return
282- self .id = job .id
283- self .dag_id = job .dag_id
284- self .state = job .state
285- self .job_type = job .job_type
286- self .start_date = job .start_date
287- self .end_date = job .end_date
288- self .latest_heartbeat = job .latest_heartbeat
289- self .executor_class = job .executor_class
290- self .hostname = job .hostname
291- self .unixname = job .unixname
292-
293282 @staticmethod
294283 def _heartrate (job_type : str ) -> float :
295284 if job_type == "TriggererJob" :
@@ -312,74 +301,9 @@ def _is_alive(
312301 and (timezone .utcnow () - latest_heartbeat ).total_seconds () < health_check_threshold_value
313302 )
314303
315- @staticmethod
316- @internal_api_call
317- @provide_session
318- def _kill (job_id : str , session : Session = NEW_SESSION ) -> Job | JobPydantic :
319- job = session .scalar (select (Job ).where (Job .id == job_id ).limit (1 ))
320- job .end_date = timezone .utcnow ()
321- session .merge (job )
322- session .commit ()
323- return job
324-
325- @staticmethod
326- @internal_api_call
327- @provide_session
328- @retry_db_transaction
329- def _fetch_from_db (job : Job | JobPydantic , session : Session = NEW_SESSION ) -> Job | JobPydantic | None :
330- if isinstance (job , Job ):
331- # not Internal API
332- session .merge (job )
333- return job
334- # Internal API,
335- return session .scalar (select (Job ).where (Job .id == job .id ).limit (1 ))
336-
337- @staticmethod
338- @internal_api_call
339- @provide_session
340- def _add_to_db (job : Job | JobPydantic , session : Session = NEW_SESSION ) -> Job | JobPydantic :
341- if isinstance (job , JobPydantic ):
342- orm_job = Job ()
343- orm_job ._merge_from (job )
344- else :
345- orm_job = job
346- session .add (orm_job )
347- session .commit ()
348- return orm_job
349-
350- @staticmethod
351- @internal_api_call
352- @provide_session
353- def _update_in_db (job : Job | JobPydantic , session : Session = NEW_SESSION ):
354- if isinstance (job , Job ):
355- # not Internal API
356- session .merge (job )
357- session .commit ()
358- # Internal API.
359- orm_job : Job | None = session .scalar (select (Job ).where (Job .id == job .id ).limit (1 ))
360- if orm_job is None :
361- return
362- orm_job ._merge_from (job )
363- session .merge (orm_job )
364- session .commit ()
365-
366- @staticmethod
367- @internal_api_call
368- @provide_session
369- @retry_db_transaction
370- def _update_heartbeat (job : Job | JobPydantic , session : Session = NEW_SESSION ) -> Job | JobPydantic :
371- orm_job : Job | None = session .scalar (select (Job ).where (Job .id == job .id ).limit (1 ))
372- if orm_job is None :
373- return job
374- orm_job .latest_heartbeat = timezone .utcnow ()
375- session .merge (orm_job )
376- session .commit ()
377- return orm_job
378-
379304
380- @internal_api_call
381305@provide_session
382- def most_recent_job (job_type : str , session : Session = NEW_SESSION ) -> Job | JobPydantic | None :
306+ def most_recent_job (job_type : str , session : Session = NEW_SESSION ) -> Job | None :
383307 """
384308 Return the most recent job of this type, if any, based on last heartbeat received.
385309
@@ -434,7 +358,7 @@ def execute_job(job: Job, execute_callable: Callable[[], int | None]) -> int | N
434358 which happens in the "complete_execution" step (which again can be executed locally in case of
435359 database operations or over the Internal API call.
436360
437- :param job: Job to execute - it can be either DB job or it's Pydantic serialized version . It does
    :param job: Job to execute — the database Job record. Apart from running the
        heartbeat and setting the job's state, the runner should not otherwise
        modify the job.
440364
0 commit comments